Commit ac61e3f5 authored by Philipp Arras's avatar Philipp Arras
Browse files

Work on simplify for constant input

parent e6188b31
......@@ -160,6 +160,39 @@ class GaussianEnergy(EnergyOperator):
metric = SandwichOperator.make(x.jac, self._icov)
return res.add_metric(metric)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .operator import _ConstantOperator
from ..multi_domain import MultiDomain
if self._icov is not None:
raise NotImplementedError
# No need to implement support for DomainTuple since this done by
# Operator.simplify_for_constant_input()
c_dom = {}
var_dom = {}
for kk in self._domain.keys():
if kk in c_inp.domain.keys():
c_dom[kk] = self._domain[kk]
else:
var_dom[kk] = self._domain[kk]
var_dom = MultiDomain.make(var_dom)
c_dom = MultiDomain.make(c_dom)
c_mean = None if self._mean is None else self._mean.extract(c_dom)
var_mean = None if self._mean is None else self._mean.extract(var_dom)
c_op = _ConstantOperator(self._domain,
GaussianEnergy(c_mean, None, c_inp.domain)(c_inp))
var_op = GaussianEnergy(var_mean, None, var_dom) #@ rest
newop = var_op + c_op
import nifty5 as ift
fld = ift.from_random('normal', newop.domain)
print(newop(fld).val)
print(self(fld).val)
print(self)
assert (newop-self)(fld).val == 0
return None, newop
class PoissonianEnergy(EnergyOperator):
"""Computes likelihood Hamiltonians of expected count field constrained by
......
......@@ -28,6 +28,7 @@ pmp = pytest.mark.parametrize
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
simplify = len(constants) == 1 and set(constants) == set(point_estimates)
np.random.seed(42)
dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3)
......@@ -51,9 +52,26 @@ def test_kl(constants, point_estimates, mirror_samples):
mirror_samples=mirror_samples,
napprox=0,
_samples=kl.samples)
if simplify:
cdom = {}
for kk in constants:
cdom[kk] = h.domain[kk]
cdom = ift.MultiDomain.make(cdom)
cst = mean0.extract(cdom)
_, hcst = h.simplify_for_constant_input(cst)
val0, val1 = h(mean0), hcst(mean0)
assert_allclose(val0.to_global_data(), val1.to_global_data())
klcst = ift.MetricGaussianKL(mean0,
hcst,
nsamps,
mirror_samples=mirror_samples,
napprox=0,
_samples=kl.samples)
# Test value
assert_allclose(kl.value, klpure.value)
if simplify:
assert_allclose(kl.value, klcst.value)
# Test gradient
for kk in h.domain.keys():
......@@ -62,6 +80,9 @@ def test_kl(constants, point_estimates, mirror_samples):
res0 = 0*res0
res1 = kl.gradient.to_global_data()[kk]
assert_allclose(res0, res1)
if simplify:
res2 = klcst.gradient.to_global_data()[kk]
assert_allclose(res0, res2)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment