Commit 06d728fc authored by Reimar Leike's avatar Reimar Leike
Browse files

Implemented Fisher tests for Multi Fields

parent f59c284f
......@@ -23,6 +23,7 @@ import nifty7 as ift
from ..common import list2fixture, setup_function, teardown_function
spaces = [ift.GLSpace(5),
ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
(ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
......@@ -58,18 +59,32 @@ def test_complex2real():
assert np.all((f == op(op.adjoint_times(f))).val)
def energy_tester_complex(pos, get_noisy_data, energy_initializer):
op = _complex2real(pos.domain)
def energy_tester(pos, get_noisy_data, energy_initializer):
if isinstance(pos, ift.Field):
if np.issubdtype(pos.dtype, np.complexfloating):
op = _complex2real(pos.domain)
else:
op = ift.ScalingOperator(pos.domain, 1.)
else:
ops = []
for k,dom in pos.domain.items():
if np.issubdtype(pos[k].dtype, np.complexfloating):
ops.append(_complex2real(dom).ducktape(k).ducktape_left(k))
else:
FA = ift.FieldAdapter(dom, k)
ops.append(FA.adjoint @ FA)
realizer = ift.utilities.my_sum(ops)
from nifty7.operator_spectrum import _DomRemover
flattener = _DomRemover(realizer.target)
op = flattener @ realizer
npos = op(pos)
nget_noisy_data = lambda mean: get_noisy_data(op.adjoint_times(mean))
nenergy_initializer = lambda mean: energy_initializer(mean) @ op.adjoint
energy_tester(npos, nget_noisy_data, nenergy_initializer)
_actual_energy_tester(npos, nget_noisy_data, nenergy_initializer)
def energy_tester(pos, get_noisy_data, energy_initializer):
if np.issubdtype(pos.dtype, np.complexfloating):
energy_tester_complex(pos, get_noisy_data, energy_initializer)
return
def _actual_energy_tester(pos, get_noisy_data, energy_initializer):
domain = pos.domain
test_vec = ift.from_random(domain, 'normal')
results = []
......@@ -101,9 +116,9 @@ def test_GaussianEnergy(field):
def test_PoissonEnergy(field):
if not isinstance(field, ift.Field):
return
pytest.skip("MultiField Poisson energy not supported")
if np.iscomplexobj(field.val):
return
pytest.skip("Poisson energy not defined for complex flux")
get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val))
# Make rate positive and high enough to avoid bad statistic
lam = 10*(field**2).clip(0.1, None)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment