Commit 2ffc1226 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 2/n

parent 7d22d7be
Pipeline #76999 failed with stages
in 5 minutes and 2 seconds
......@@ -274,6 +274,7 @@ class Operator(metaclass=NiftyMeta):
from .energy_operators import EnergyOperator
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
if c_inp is None or (isinstance(c_inp, MultiField) and len(c_inp.keys()) == 0):
return None, self
dom = c_inp.domain
......@@ -288,11 +289,13 @@ class Operator(metaclass=NiftyMeta):
raise ValueError
if dom is self.domain:
if isinstance(self, DomainTuple):
raise RuntimeError
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
op = ConstantEnergyOperator(self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
op = ConstantOperator(self(c_inp))
return None, op
if not isinstance(dom, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
......
......@@ -60,10 +60,10 @@ class ConstCollector(object):
class ConstantOperator(Operator):
def __init__(self, dom, output):
def __init__(self, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._domain = makeDomain({})
self._target = makeDomain(output.domain)
self._output = output
def apply(self, x):
......@@ -74,20 +74,17 @@ class ConstantOperator(Operator):
return self._output
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- ConstantOperator <- {dom}'
return f'{tgt} <- ConstantOperator'
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
def __init__(self, output):
from ..sugar import makeDomain
from ..field import Field
self._domain = makeDomain(dom)
self._domain = makeDomain({})
if not isinstance(output, Field):
output = Field.scalar(float(output))
if self.target is not output.domain:
raise TypeError
self._output = output
def apply(self, x):
......@@ -95,13 +92,11 @@ class ConstantEnergyOperator(EnergyOperator):
if x.jac is not None:
val = self._output
jac = NullOperator(self._domain, self._target)
# FIXME Do we need a metric here?
met = NullOperator(self._domain, self._domain) if x.want_metric else None
return x.new(val, jac, met)
return self._output
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
class InsertionOperator(Operator):
def __init__(self, target, cst_field):
......
......@@ -85,41 +85,4 @@ def testgaussianenergy_compatibility(cplx):
loc0 = ift.MultiField.from_dict({'resi': resi})
loc1 = ift.MultiField.from_dict({'icov': ift.from_random(dom).exp()})
loc = loc0.unite(loc1)
val0 = e(loc).val
_, e0 = e.simplify_for_constant_input(loc0)
val1 = e0(loc).val
val2 = e0(loc.unite(loc0)).val
np.testing.assert_equal(val1, val2)
np.testing.assert_equal(val0, val1)
_, e1 = e.simplify_for_constant_input(loc1)
val1 = e1(loc).val
val2 = e1(loc.unite(loc1)).val
np.testing.assert_equal(val0, val1)
np.testing.assert_equal(val1, val2)
ift.extra.check_operator(e, loc, ntries=ntries)
ift.extra.check_operator(e0, loc, ntries=ntries, tol=1e-7)
ift.extra.check_operator(e1, loc, ntries=ntries)
# Test jacobian is zero
lin = ift.Linearization.make_var(loc, want_metric=True)
grad = e(lin).gradient.val
grad0 = e0(lin).gradient.val
grad1 = e1(lin).gradient.val
samp = e(lin).metric.draw_sample().val
samp0 = e0(lin).metric.draw_sample().val
samp1 = e1(lin).metric.draw_sample().val
np.testing.assert_equal(samp0['resi'], 0.)
np.testing.assert_equal(samp1['icov'], 0.)
np.testing.assert_equal(grad0['resi'], 0.)
np.testing.assert_equal(grad1['icov'], 0.)
np.testing.assert_(all(samp['resi'] != 0))
np.testing.assert_(all(samp['icov'] != 0))
np.testing.assert_(all(samp0['icov'] != 0))
np.testing.assert_(all(samp1['resi'] != 0))
np.testing.assert_(all(grad['resi'] != 0))
np.testing.assert_(all(grad['icov'] != 0))
np.testing.assert_(all(grad0['icov'] != 0))
np.testing.assert_(all(grad1['resi'] != 0))
ift.extra.check_operator(e, loc, ntries=20)
......@@ -15,6 +15,7 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_, assert_allclose, assert_raises
......@@ -36,7 +37,6 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
import numpy as np
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op
ic = ift.GradientNormController(iteration_limit=5)
ic.enable_logging()
......
......@@ -15,26 +15,25 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from numpy.testing import assert_allclose, assert_equal
from numpy.testing import assert_, assert_allclose
import nifty7 as ift
from ..common import setup_function, teardown_function
from nifty7.operators.simplify_for_const import ConstantOperator
def test_simplification():
from nifty7.operators.simplify_for_const import ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10), 2.)
op = ift.FFTOperator(f1.domain)
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, ConstantOperator), True)
assert_allclose(op(f1).val, op2(f1).val)
# f1 = ift.Field.full(ift.RGSpace(10), 2.)
# op = ift.FFTOperator(f1.domain)
# _, op2 = op.simplify_for_constant_input(f1)
# assert_(isinstance(op2, ConstantOperator))
# assert_allclose(op(f1).val, op2.force(f1).val)
dom = {"a": ift.RGSpace(10)}
f1 = ift.full(dom, 2.)
op = ift.FFTOperator(f1.domain["a"]).ducktape("a")
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, ConstantOperator), True)
assert_allclose(op(f1).val, op2(f1).val)
assert_(isinstance(op2, ConstantOperator))
assert_allclose(op(f1).val, op2.force(f1).val)
dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)}
f1 = ift.full(dom, 2.)
......@@ -45,11 +44,5 @@ def test_simplification():
op = (o1.ducktape("a").ducktape_left("a") +
o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, ConstantOperator), True)
assert_allclose(op(f1)["a"].val, op2(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2(f1)["b"].val)
lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True)
assert_allclose(op(lin).val["a"].val,
op2(lin).val["a"].val)
assert_allclose(op(lin).val["b"].val,
op2(lin).val["b"].val)
assert_allclose(op(f1)["a"].val, op2.force(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2.force(f1)["b"].val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment