From b99e06c0d96060c0ca1e3f98cc5bed42ad694d88 Mon Sep 17 00:00:00 2001 From: Reimar Leike <reimar@mpa-garhcing.mpg.de> Date: Wed, 17 Jun 2020 18:19:41 +0200 Subject: [PATCH] Adjusted fisher test to always make Fisher matrices reaL, code for GaussianEnergy was reverted --- nifty6/operators/energy_operators.py | 13 +--------- test/test_operators/test_fisher_metric.py | 31 ++++++++++++++++++++++- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index 21949965b..627c09eaa 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -244,18 +244,7 @@ class GaussianEnergy(EnergyOperator): self._met = inverse_covariance if sampling_dtype is not None: self._met = SamplingDtypeSetter(self._met, sampling_dtype) - if isinstance(sampling_dtype, dict): - from .sandwich_operator import SandwichOperator - scale = {k:np.sqrt(2.) if np.issubdtype(v, np.complexfloating) - else 1. for k,v in sampling_dtype.items()} - scale = _build_MultiScalingOperator(self._domain, scale) - self._met = SandwichOperator.make(scale, self._met) - else: - if np.issubdtype(sampling_dtype, np.complexfloating): - from .sandwich_operator import SandwichOperator - scale = ScalingOperator(self._met.domain,np.sqrt(2)) - self._met = SandwichOperator.make(scale, self._met) - + def _checkEquivalence(self, newdom): newdom = makeDomain(newdom) if self._domain is None: diff --git a/test/test_operators/test_fisher_metric.py b/test/test_operators/test_fisher_metric.py index 46edccca8..bacb89244 100644 --- a/test/test_operators/test_fisher_metric.py +++ b/test/test_operators/test_fisher_metric.py @@ -23,7 +23,6 @@ import nifty6 as ift from ..common import list2fixture, setup_function, teardown_function spaces = [ift.GLSpace(5), - ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}), (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))] pmp = pytest.mark.parametrize field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] + @@ -38,7 +37,34 @@ def _to_array(d): assert isinstance(d, dict) return np.concatenate(list(d.values())) +def _complex2real(sp): + tup = tuple([d for d in sp]) + rsp = ift.DomainTuple.make((ift.UnstructuredDomain(2),) + tup) + rl = ift.DomainTupleFieldInserter(rsp, 0, (0,)) + im = ift.DomainTupleFieldInserter(rsp, 0, (1,)) + x = ift.ScalingOperator(sp, 1) + return rl(x.real)+im(x.imag) + +def test_complex2real(): + sp = ift.UnstructuredDomain(3) + op = _complex2real(ift.makeDomain(sp)) + f = ift.from_random(op.domain, 'normal', dtype=np.complex128) + assert np.all((f == op.adjoint_times(op(f))).val) + assert op(f).dtype == np.float64 + f = ift.from_random(op.target, 'normal') + assert np.all((f == op(op.adjoint_times(f))).val) + +def energy_tester_complex(pos, get_noisy_data, energy_initializer): + op = _complex2real(pos.domain) + npos = op(pos) + nget_noisy_data = lambda mean : get_noisy_data(op.adjoint_times(mean)) + nenergy_initializer = lambda mean : energy_initializer(mean) @ op.adjoint + energy_tester(npos, nget_noisy_data, nenergy_initializer) + def energy_tester(pos, get_noisy_data, energy_initializer): + if np.issubdtype(pos.dtype, np.complexfloating): + energy_tester_complex(pos, get_noisy_data, energy_initializer) + return domain = pos.domain test_vec = ift.from_random(domain, 'normal') results = [] @@ -48,6 +74,8 @@ def energy_tester(pos, get_noisy_data, energy_initializer): energy = energy_initializer(data) grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.)) results.append(_to_array((grad*grad.s_vdot(test_vec)).val)) + print(energy) + print(grad) res = np.mean(np.array(results), axis=0) std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp) energy = energy_initializer(data) @@ -57,6 +85,7 @@ def energy_tester(pos, get_noisy_data, energy_initializer): def test_GaussianEnergy(field): dtype = field.dtype + icov = ift.from_random(field.domain, 'normal')**2 icov = ift.makeOp(icov) get_noisy_data = lambda mean : mean + icov.draw_sample_with_dtype( -- GitLab