From b99e06c0d96060c0ca1e3f98cc5bed42ad694d88 Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@mpa-garhcing.mpg.de>
Date: Wed, 17 Jun 2020 18:19:41 +0200
Subject: [PATCH] Adjusted fisher test to always make Fisher matrices reaL,
 code for GaussianEnergy was reverted

---
 nifty6/operators/energy_operators.py      | 13 +---------
 test/test_operators/test_fisher_metric.py | 31 ++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py
index 21949965b..627c09eaa 100644
--- a/nifty6/operators/energy_operators.py
+++ b/nifty6/operators/energy_operators.py
@@ -244,18 +244,7 @@ class GaussianEnergy(EnergyOperator):
             self._met = inverse_covariance
         if sampling_dtype is not None:
             self._met = SamplingDtypeSetter(self._met, sampling_dtype)
-            if isinstance(sampling_dtype, dict):
-                from .sandwich_operator import SandwichOperator
-                scale = {k:np.sqrt(2.) if np.issubdtype(v, np.complexfloating)
-                        else 1. for k,v in sampling_dtype.items()}
-                scale = _build_MultiScalingOperator(self._domain, scale)
-                self._met = SandwichOperator.make(scale, self._met)
-            else:
-                if np.issubdtype(sampling_dtype, np.complexfloating):
-                    from .sandwich_operator import SandwichOperator
-                    scale = ScalingOperator(self._met.domain,np.sqrt(2))
-                    self._met = SandwichOperator.make(scale, self._met)
-    
+
     def _checkEquivalence(self, newdom):
         newdom = makeDomain(newdom)
         if self._domain is None:
diff --git a/test/test_operators/test_fisher_metric.py b/test/test_operators/test_fisher_metric.py
index 46edccca8..bacb89244 100644
--- a/test/test_operators/test_fisher_metric.py
+++ b/test/test_operators/test_fisher_metric.py
@@ -23,7 +23,6 @@ import nifty6 as ift
 from ..common import list2fixture, setup_function, teardown_function
 
 spaces = [ift.GLSpace(5),
-          ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
           (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
 pmp = pytest.mark.parametrize
 field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
@@ -38,7 +37,34 @@ def _to_array(d):
     assert isinstance(d, dict)
     return np.concatenate(list(d.values()))
 
+def _complex2real(sp):
+    tup = tuple([d for d in sp])
+    rsp = ift.DomainTuple.make((ift.UnstructuredDomain(2),) + tup)
+    rl = ift.DomainTupleFieldInserter(rsp, 0, (0,))
+    im = ift.DomainTupleFieldInserter(rsp, 0, (1,))
+    x = ift.ScalingOperator(sp, 1)
+    return rl(x.real)+im(x.imag)
+
+def test_complex2real():
+    sp = ift.UnstructuredDomain(3)
+    op = _complex2real(ift.makeDomain(sp))
+    f = ift.from_random(op.domain, 'normal', dtype=np.complex128)
+    assert np.all((f == op.adjoint_times(op(f))).val)
+    assert op(f).dtype == np.float64
+    f = ift.from_random(op.target, 'normal')
+    assert np.all((f == op(op.adjoint_times(f))).val)
+    
+def energy_tester_complex(pos, get_noisy_data, energy_initializer):
+    op = _complex2real(pos.domain)
+    npos = op(pos)
+    nget_noisy_data = lambda mean : get_noisy_data(op.adjoint_times(mean))
+    nenergy_initializer = lambda mean : energy_initializer(mean) @ op.adjoint
+    energy_tester(npos, nget_noisy_data, nenergy_initializer)
+
 def energy_tester(pos, get_noisy_data, energy_initializer):
+    if np.issubdtype(pos.dtype, np.complexfloating):
+        energy_tester_complex(pos, get_noisy_data, energy_initializer)
+        return
     domain = pos.domain
     test_vec = ift.from_random(domain, 'normal')
     results = []
@@ -48,6 +74,8 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
         energy = energy_initializer(data)
         grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.))
         results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
+    print(energy)
+    print(grad)
     res = np.mean(np.array(results), axis=0)
     std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
     energy = energy_initializer(data)
@@ -57,6 +85,7 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
 
 def test_GaussianEnergy(field):
     dtype = field.dtype
+
     icov = ift.from_random(field.domain, 'normal')**2
     icov = ift.makeOp(icov)
     get_noisy_data = lambda mean : mean + icov.draw_sample_with_dtype(
-- 
GitLab