From 151b37f4af7aaedb2b0d8035fa010b6ac78094fa Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Sun, 21 Jun 2020 13:01:39 +0200
Subject: [PATCH] Refactoring and add test

---
 src/minimization/metric_gaussian_kl.py     | 23 ++++++++++--------
 test/test_operators/test_simplification.py | 27 +++++++++++++++++++++-
 2 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/src/minimization/metric_gaussian_kl.py b/src/minimization/metric_gaussian_kl.py
index 65b95b079..475da112a 100644
--- a/src/minimization/metric_gaussian_kl.py
+++ b/src/minimization/metric_gaussian_kl.py
@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain):
     """Takes only keys from sample which are also in domain and inserts zeros
     in sample if key is not in domain."""
     from ..multi_domain import MultiDomain
-    if not isinstance(sample, MultiField):
-        assert sample.domain is domain
-        return sample
-    assert isinstance(domain, MultiDomain)
-    if sample.domain is domain:
+    from ..field import Field
+    from ..domain_tuple import DomainTuple
+    from ..sugar import makeDomain
+    domain = makeDomain(domain)
+    if isinstance(domain, DomainTuple) and isinstance(sample, Field):
+        if sample.domain is not domain:
+            raise TypeError
         return sample
-    out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
-    out = MultiField.from_dict(out, domain)
-    assert domain is out.domain
-    return out
+    elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
+        if sample.domain is domain:
+            return sample
+        out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
+        out = MultiField.from_dict(out, domain)
+        return out
+    raise TypeError
 
 
 class MetricGaussianKL(Energy):
diff --git a/test/test_operators/test_simplification.py b/test/test_operators/test_simplification.py
index 6a0814d5e..f1fa7420c 100644
--- a/test/test_operators/test_simplification.py
+++ b/test/test_operators/test_simplification.py
@@ -15,7 +15,7 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
-from numpy.testing import assert_, assert_allclose
+from numpy.testing import assert_, assert_allclose, assert_raises
 
 import nifty7 as ift
 from nifty7.operators.simplify_for_const import ConstantOperator
@@ -41,3 +41,28 @@ def test_simplification():
     assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val)
     assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val)
     # FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial()
+
+
+def test_modify_sample_domain():
+    func = ift.minimization.metric_gaussian_kl._modify_sample_domain
+    dom0 = ift.RGSpace(1)
+    dom1 = ift.RGSpace(2)
+    field = ift.full(dom0, 1.)
+    ift.extra.assert_equal(func(field, dom0), field)
+
+    mdom0 = ift.makeDomain({'a': dom0, 'b': dom1})
+    mdom1 = ift.makeDomain({'a': dom0})
+    mfield0 = ift.full(mdom0, 1.)
+    mfield1 = ift.full(mdom1, 1.)
+    mfield01 = ift.MultiField.from_dict({'a': ift.full(dom0, 1.),
+                                         'b': ift.full(dom1, 0.)})
+
+    ift.extra.assert_equal(func(mfield0, mdom0), mfield0)
+    ift.extra.assert_equal(func(mfield0, mdom1), mfield1)
+    ift.extra.assert_equal(func(mfield1, mdom0), mfield01)
+    ift.extra.assert_equal(func(mfield1, mdom1), mfield1)
+
+    with assert_raises(TypeError):
+        func(mfield0, dom0)
+    with assert_raises(TypeError):
+        func(field, dom1)
-- 
GitLab