Commit 151b37f4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactoring and add test

parent d1ab800e
Pipeline #77030 passed with stages
in 12 minutes and 22 seconds
......@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain."""
from ..multi_domain import MultiDomain
if not isinstance(sample, MultiField):
assert sample.domain is domain
return sample
assert isinstance(domain, MultiDomain)
if sample.domain is domain:
from ..field import Field
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple) and isinstance(sample, Field):
if sample.domain is not domain:
raise TypeError
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
assert domain is out.domain
return out
elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
if sample.domain is domain:
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
return out
raise TypeError
class MetricGaussianKL(Energy):
......@@ -15,7 +15,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from numpy.testing import assert_, assert_allclose
from numpy.testing import assert_, assert_allclose, assert_raises
import nifty7 as ift
from nifty7.operators.simplify_for_const import ConstantOperator
......@@ -41,3 +41,28 @@ def test_simplification():
assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val)
# FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial()
def test_modify_sample_domain():
func = ift.minimization.metric_gaussian_kl._modify_sample_domain
dom0 = ift.RGSpace(1)
dom1 = ift.RGSpace(2)
field = ift.full(dom0, 1.)
ift.extra.assert_equal(func(field, dom0), field)
mdom0 = ift.makeDomain({'a': dom0, 'b': dom1})
mdom1 = ift.makeDomain({'a': dom0})
mfield0 = ift.full(mdom0, 1.)
mfield1 = ift.full(mdom1, 1.)
mfield01 = ift.MultiField.from_dict({'a': ift.full(dom0, 1.),
'b': ift.full(dom1, 0.)})
ift.extra.assert_equal(func(mfield0, mdom0), mfield0)
ift.extra.assert_equal(func(mfield0, mdom1), mfield1)
ift.extra.assert_equal(func(mfield1, mdom0), mfield01)
ift.extra.assert_equal(func(mfield1, mdom1), mfield1)
with assert_raises(TypeError):
func(mfield0, dom0)
with assert_raises(TypeError):
func(field, dom1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment