Commit 151b37f4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactoring and add test

parent d1ab800e
Pipeline #77030 passed with stages
in 12 minutes and 22 seconds
...@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain): ...@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros """Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain.""" in sample if key is not in domain."""
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
if not isinstance(sample, MultiField): from ..field import Field
assert sample.domain is domain from ..domain_tuple import DomainTuple
return sample from ..sugar import makeDomain
assert isinstance(domain, MultiDomain) domain = makeDomain(domain)
if sample.domain is domain: if isinstance(domain, DomainTuple) and isinstance(sample, Field):
if sample.domain is not domain:
raise TypeError
return sample return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()} elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
out = MultiField.from_dict(out, domain) if sample.domain is domain:
assert domain is out.domain return sample
return out out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
return out
raise TypeError
class MetricGaussianKL(Energy): class MetricGaussianKL(Energy):
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from numpy.testing import assert_, assert_allclose from numpy.testing import assert_, assert_allclose, assert_raises
import nifty7 as ift import nifty7 as ift
from nifty7.operators.simplify_for_const import ConstantOperator from nifty7.operators.simplify_for_const import ConstantOperator
...@@ -41,3 +41,28 @@ def test_simplification(): ...@@ -41,3 +41,28 @@ def test_simplification():
assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val) assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val) assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val)
# FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial() # FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial()
def test_modify_sample_domain():
func = ift.minimization.metric_gaussian_kl._modify_sample_domain
dom0 = ift.RGSpace(1)
dom1 = ift.RGSpace(2)
field = ift.full(dom0, 1.)
ift.extra.assert_equal(func(field, dom0), field)
mdom0 = ift.makeDomain({'a': dom0, 'b': dom1})
mdom1 = ift.makeDomain({'a': dom0})
mfield0 = ift.full(mdom0, 1.)
mfield1 = ift.full(mdom1, 1.)
mfield01 = ift.MultiField.from_dict({'a': ift.full(dom0, 1.),
'b': ift.full(dom1, 0.)})
ift.extra.assert_equal(func(mfield0, mdom0), mfield0)
ift.extra.assert_equal(func(mfield0, mdom1), mfield1)
ift.extra.assert_equal(func(mfield1, mdom0), mfield01)
ift.extra.assert_equal(func(mfield1, mdom1), mfield1)
with assert_raises(TypeError):
func(mfield0, dom0)
with assert_raises(TypeError):
func(field, dom1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment