From 151b37f4af7aaedb2b0d8035fa010b6ac78094fa Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Sun, 21 Jun 2020 13:01:39 +0200 Subject: [PATCH] Refactoring and add test --- src/minimization/metric_gaussian_kl.py | 23 ++++++++++-------- test/test_operators/test_simplification.py | 27 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/minimization/metric_gaussian_kl.py b/src/minimization/metric_gaussian_kl.py index 65b95b079..475da112a 100644 --- a/src/minimization/metric_gaussian_kl.py +++ b/src/minimization/metric_gaussian_kl.py @@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain): """Takes only keys from sample which are also in domain and inserts zeros in sample if key is not in domain.""" from ..multi_domain import MultiDomain - if not isinstance(sample, MultiField): - assert sample.domain is domain - return sample - assert isinstance(domain, MultiDomain) - if sample.domain is domain: + from ..field import Field + from ..domain_tuple import DomainTuple + from ..sugar import makeDomain + domain = makeDomain(domain) + if isinstance(domain, DomainTuple) and isinstance(sample, Field): + if sample.domain is not domain: + raise TypeError return sample - out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()} - out = MultiField.from_dict(out, domain) - assert domain is out.domain - return out + elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField): + if sample.domain is domain: + return sample + out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()} + out = MultiField.from_dict(out, domain) + return out + raise TypeError class MetricGaussianKL(Energy): diff --git a/test/test_operators/test_simplification.py b/test/test_operators/test_simplification.py index 6a0814d5e..f1fa7420c 100644 --- a/test/test_operators/test_simplification.py +++ b/test/test_operators/test_simplification.py @@ -15,7 +15,7 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from numpy.testing import assert_, assert_allclose +from numpy.testing import assert_, assert_allclose, assert_raises import nifty7 as ift from nifty7.operators.simplify_for_const import ConstantOperator @@ -41,3 +41,28 @@ def test_simplification(): assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val) assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val) # FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial() + + +def test_modify_sample_domain(): + func = ift.minimization.metric_gaussian_kl._modify_sample_domain + dom0 = ift.RGSpace(1) + dom1 = ift.RGSpace(2) + field = ift.full(dom0, 1.) + ift.extra.assert_equal(func(field, dom0), field) + + mdom0 = ift.makeDomain({'a': dom0, 'b': dom1}) + mdom1 = ift.makeDomain({'a': dom0}) + mfield0 = ift.full(mdom0, 1.) + mfield1 = ift.full(mdom1, 1.) + mfield01 = ift.MultiField.from_dict({'a': ift.full(dom0, 1.), + 'b': ift.full(dom1, 0.)}) + + ift.extra.assert_equal(func(mfield0, mdom0), mfield0) + ift.extra.assert_equal(func(mfield0, mdom1), mfield1) + ift.extra.assert_equal(func(mfield1, mdom0), mfield01) + ift.extra.assert_equal(func(mfield1, mdom1), mfield1) + + with assert_raises(TypeError): + func(mfield0, dom0) + with assert_raises(TypeError): + func(field, dom1) -- GitLab