diff --git a/test/test_operators/test_fisher_metric.py b/test/test_operators/test_fisher_metric.py index f6a8066e23f8c15553112c671048424b9cb117aa..6700bd022619137d3426923d9f43d5f359260f23 100644 --- a/test/test_operators/test_fisher_metric.py +++ b/test/test_operators/test_fisher_metric.py @@ -23,6 +23,7 @@ import nifty7 as ift from ..common import list2fixture, setup_function, teardown_function spaces = [ift.GLSpace(5), + ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}), (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))] pmp = pytest.mark.parametrize field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] + @@ -58,18 +59,32 @@ def test_complex2real(): assert np.all((f == op(op.adjoint_times(f))).val) -def energy_tester_complex(pos, get_noisy_data, energy_initializer): - op = _complex2real(pos.domain) +def energy_tester(pos, get_noisy_data, energy_initializer): + if isinstance(pos, ift.Field): + if np.issubdtype(pos.dtype, np.complexfloating): + op = _complex2real(pos.domain) + else: + op = ift.ScalingOperator(pos.domain, 1.) + else: + ops = [] + for k,dom in pos.domain.items(): + if np.issubdtype(pos[k].dtype, np.complexfloating): + ops.append(_complex2real(dom).ducktape(k).ducktape_left(k)) + else: + FA = ift.FieldAdapter(dom, k) + ops.append(FA.adjoint @ FA) + realizer = ift.utilities.my_sum(ops) + from nifty7.operator_spectrum import _DomRemover + flattener = _DomRemover(realizer.target) + op = flattener @ realizer + npos = op(pos) nget_noisy_data = lambda mean: get_noisy_data(op.adjoint_times(mean)) nenergy_initializer = lambda mean: energy_initializer(mean) @ op.adjoint - energy_tester(npos, nget_noisy_data, nenergy_initializer) + _actual_energy_tester(npos, nget_noisy_data, nenergy_initializer) -def energy_tester(pos, get_noisy_data, energy_initializer): - if np.issubdtype(pos.dtype, np.complexfloating): - energy_tester_complex(pos, get_noisy_data, energy_initializer) - return +def _actual_energy_tester(pos, get_noisy_data, energy_initializer): domain = pos.domain test_vec = ift.from_random(domain, 'normal') results = [] @@ -101,9 +116,9 @@ def test_GaussianEnergy(field): def test_PoissonEnergy(field): if not isinstance(field, ift.Field): - return + pytest.skip("MultiField Poisson energy not supported") if np.iscomplexobj(field.val): - return + pytest.skip("Poisson energy not defined for complex flux") get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val)) # Make rate positive and high enough to avoid bad statistic lam = 10*(field**2).clip(0.1, None)