Skip to content
Snippets Groups Projects
Commit 151b37f4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactoring and add test

parent d1ab800e
No related branches found
No related tags found
1 merge request!545Proper constants
Pipeline #77030 passed
......@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain."""
from ..multi_domain import MultiDomain
if not isinstance(sample, MultiField):
assert sample.domain is domain
return sample
assert isinstance(domain, MultiDomain)
if sample.domain is domain:
from ..field import Field
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple) and isinstance(sample, Field):
if sample.domain is not domain:
raise TypeError
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
assert domain is out.domain
return out
elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
if sample.domain is domain:
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
return out
raise TypeError
class MetricGaussianKL(Energy):
......
......@@ -15,7 +15,7 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from numpy.testing import assert_, assert_allclose
from numpy.testing import assert_, assert_allclose, assert_raises
import nifty7 as ift
from nifty7.operators.simplify_for_const import ConstantOperator
......@@ -41,3 +41,28 @@ def test_simplification():
assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val)
# FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial()
def test_modify_sample_domain():
func = ift.minimization.metric_gaussian_kl._modify_sample_domain
dom0 = ift.RGSpace(1)
dom1 = ift.RGSpace(2)
field = ift.full(dom0, 1.)
ift.extra.assert_equal(func(field, dom0), field)
mdom0 = ift.makeDomain({'a': dom0, 'b': dom1})
mdom1 = ift.makeDomain({'a': dom0})
mfield0 = ift.full(mdom0, 1.)
mfield1 = ift.full(mdom1, 1.)
mfield01 = ift.MultiField.from_dict({'a': ift.full(dom0, 1.),
'b': ift.full(dom1, 0.)})
ift.extra.assert_equal(func(mfield0, mdom0), mfield0)
ift.extra.assert_equal(func(mfield0, mdom1), mfield1)
ift.extra.assert_equal(func(mfield1, mdom0), mfield01)
ift.extra.assert_equal(func(mfield1, mdom1), mfield1)
with assert_raises(TypeError):
func(mfield0, dom0)
with assert_raises(TypeError):
func(field, dom1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment