From 90bcde2c580736f992308efdbd8105f83c43725f Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@mpa-garhcing.mpg.de>
Date: Fri, 19 Jun 2020 17:32:31 +0200
Subject: [PATCH] Add test for VariableCovarianceGaussianEnergy

---
 test/test_operators/test_fisher_metric.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/test/test_operators/test_fisher_metric.py b/test/test_operators/test_fisher_metric.py
index 6700bd022..1308cee50 100644
--- a/test/test_operators/test_fisher_metric.py
+++ b/test/test_operators/test_fisher_metric.py
@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize
 field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
                      [ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces])
 
+dtype = list2fixture([np.float64,
+                     np.complex128])
 Nsamp = 2000
 np.random.seed(42)
 
@@ -110,7 +112,7 @@ def test_GaussianEnergy(field):
     icov = ift.makeOp(icov)
     get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype(
         from_inverse=True, dtype=dtype)
-    E_init = lambda mean: ift.GaussianEnergy(mean=mean, inverse_covariance=icov)
+    E_init = lambda data: ift.GaussianEnergy(mean=data, inverse_covariance=icov)
     energy_tester(field, get_noisy_data, E_init)
 
 
@@ -122,5 +124,20 @@ def test_PoissonEnergy(field):
     get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val))
     # Make rate positive and high enough to avoid bad statistic
     lam = 10*(field**2).clip(0.1, None)
-    E_init = lambda mean: ift.PoissonianEnergy(mean)
+    E_init = lambda data: ift.PoissonianEnergy(data)
     energy_tester(lam, get_noisy_data, E_init)
+
+def test_VariableCovarianceGaussianEnergy(dtype):
+    dom = ift.UnstructuredDomain(3)
+    res = ift.from_random(dom, 'normal', dtype=dtype)
+    ivar = ift.from_random(dom, 'normal')**2+4.
+    mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar})
+    energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)
+    def get_noisy_data(mean):
+        samp = ift.from_random(dom, 'normal', dtype)
+        samp = samp/mean['ivar'].sqrt()
+        return samp + mean['res']
+    def E_init(data):
+        adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
+        return energy.partial_insert(adder)
+    energy_tester(mf, get_noisy_data, E_init)
-- 
GitLab