From 743e98c9e262e6aa37cb958190b57763d592c52c Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Mon, 9 Mar 2020 12:22:30 +0100
Subject: [PATCH] MultiField tests for energies

---
 test/test_energy_gradients.py | 66 ++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 28 deletions(-)

diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py
index 3344bdbb1..5f82ee5fe 100644
--- a/test/test_energy_gradients.py
+++ b/test/test_energy_gradients.py
@@ -28,26 +28,18 @@ from itertools import product
 SPACES = [ift.GLSpace(15),
           ift.RGSpace(64, distances=.789),
           ift.RGSpace([32, 32], distances=.789)]
+for sp in SPACES[:3]:
+    SPACES.append(ift.MultiDomain.make({'asdf': sp}))
 SEEDS = [4, 78, 23]
 PARAMS = product(SEEDS, SPACES)
 pmp = pytest.mark.parametrize
-# FIXME Test also with multifields in domain
 
 
 @pytest.fixture(params=PARAMS)
 def field(request):
     np.random.seed(request.param[0])
     S = ift.ScalingOperator(request.param[1], 1.)
-    s = S.draw_sample()
-    return ift.MultiField.from_dict({'s1': s})['s1']
-
-
-def test_variablecovariancegaussian(field):
-    dc = {'a': field, 'b': field.exp()}
-    mf = ift.MultiField.from_dict(dc)
-    energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
-    ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
-    energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
+    return S.draw_sample()
 
 
 def test_gaussian(field):
@@ -55,26 +47,52 @@ def test_gaussian(field):
     ift.extra.check_jacobian_consistency(energy, field)
 
 
-@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
-              lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
-def test_ScaledEnergy(field, icov):
-    icov = icov(field.domain)
+def test_ScaledEnergy(field):
+    icov = ift.ScalingOperator(field.domain, 1.2)
     energy = ift.GaussianEnergy(inverse_covariance=icov)
     ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
 
     lin = ift.Linearization.make_var(field, want_metric=True)
     met1 = energy(lin).metric
     met2 = energy.scale(0.3)(lin).metric
-    np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12)
+    res1 = met1(field)
+    res2 = met2(field)/0.3
+    ift.extra.assert_allclose(res1, res2, 0, 1e-12)
     met2.draw_sample()
 
 
 def test_studentt(field):
+    if isinstance(field.domain, ift.MultiDomain):
+        return
     energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
     ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
 
 
+def test_hamiltonian_and_KL(field):
+    field = field.exp()
+    space = field.domain
+    lh = ift.GaussianEnergy(domain=space)
+    hamiltonian = ift.StandardHamiltonian(lh)
+    ift.extra.check_jacobian_consistency(hamiltonian, field)
+    S = ift.ScalingOperator(space, 1.)
+    samps = [S.draw_sample() for i in range(3)]
+    kl = ift.AveragedEnergy(hamiltonian, samps)
+    ift.extra.check_jacobian_consistency(kl, field)
+
+
+def test_variablecovariancegaussian(field):
+    if isinstance(field.domain, ift.MultiDomain):
+        return
+    dc = {'a': field, 'b': field.exp()}
+    mf = ift.MultiField.from_dict(dc)
+    energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
+    ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
+    energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
+
+
 def test_inverse_gamma(field):
+    if isinstance(field.domain, ift.MultiDomain):
+        return
     field = field.exp()
     space = field.domain
     d = np.random.normal(10, size=space.shape)**2
@@ -84,6 +102,8 @@ def test_inverse_gamma(field):
 
 
 def testPoissonian(field):
+    if isinstance(field.domain, ift.MultiDomain):
+        return
     field = field.exp()
     space = field.domain
     d = np.random.poisson(120, size=space.shape)
@@ -92,19 +112,9 @@ def testPoissonian(field):
     ift.extra.check_jacobian_consistency(energy, field, tol=1e-7)
 
 
-def test_hamiltonian_and_KL(field):
-    field = field.exp()
-    space = field.domain
-    lh = ift.GaussianEnergy(domain=space)
-    hamiltonian = ift.StandardHamiltonian(lh)
-    ift.extra.check_jacobian_consistency(hamiltonian, field)
-    S = ift.ScalingOperator(space, 1.)
-    samps = [S.draw_sample() for i in range(3)]
-    kl = ift.AveragedEnergy(hamiltonian, samps)
-    ift.extra.check_jacobian_consistency(kl, field)
-
-
 def test_bernoulli(field):
+    if isinstance(field.domain, ift.MultiDomain):
+        return
     field = field.sigmoid()
     space = field.domain
     d = np.random.binomial(1, 0.1, size=space.shape)
-- 
GitLab