Commit ac61e3f5 by Philipp Arras

### Work on simplify for constant input

parent e6188b31
 ... ... @@ -160,6 +160,39 @@ class GaussianEnergy(EnergyOperator): metric = SandwichOperator.make(x.jac, self._icov) return res.add_metric(metric) def _simplify_for_constant_input_nontrivial(self, c_inp): from .operator import _ConstantOperator from ..multi_domain import MultiDomain if self._icov is not None: raise NotImplementedError # No need to implement support for DomainTuple since this done by # Operator.simplify_for_constant_input() c_dom = {} var_dom = {} for kk in self._domain.keys(): if kk in c_inp.domain.keys(): c_dom[kk] = self._domain[kk] else: var_dom[kk] = self._domain[kk] var_dom = MultiDomain.make(var_dom) c_dom = MultiDomain.make(c_dom) c_mean = None if self._mean is None else self._mean.extract(c_dom) var_mean = None if self._mean is None else self._mean.extract(var_dom) c_op = _ConstantOperator(self._domain, GaussianEnergy(c_mean, None, c_inp.domain)(c_inp)) var_op = GaussianEnergy(var_mean, None, var_dom) #@ rest newop = var_op + c_op import nifty5 as ift fld = ift.from_random('normal', newop.domain) print(newop(fld).val) print(self(fld).val) print(self) assert (newop-self)(fld).val == 0 return None, newop class PoissonianEnergy(EnergyOperator): """Computes likelihood Hamiltonians of expected count field constrained by ... ...
 ... ... @@ -28,6 +28,7 @@ pmp = pytest.mark.parametrize @pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b'])) @pmp('mirror_samples', (True, False)) def test_kl(constants, point_estimates, mirror_samples): simplify = len(constants) == 1 and set(constants) == set(point_estimates) np.random.seed(42) dom = ift.RGSpace((12,), (2.12)) op0 = ift.HarmonicSmoothingOperator(dom, 3) ... ... @@ -51,9 +52,26 @@ def test_kl(constants, point_estimates, mirror_samples): mirror_samples=mirror_samples, napprox=0, _samples=kl.samples) if simplify: cdom = {} for kk in constants: cdom[kk] = h.domain[kk] cdom = ift.MultiDomain.make(cdom) cst = mean0.extract(cdom) _, hcst = h.simplify_for_constant_input(cst) val0, val1 = h(mean0), hcst(mean0) assert_allclose(val0.to_global_data(), val1.to_global_data()) klcst = ift.MetricGaussianKL(mean0, hcst, nsamps, mirror_samples=mirror_samples, napprox=0, _samples=kl.samples) # Test value assert_allclose(kl.value, klpure.value) if simplify: assert_allclose(kl.value, klcst.value) # Test gradient for kk in h.domain.keys(): ... ... @@ -62,6 +80,9 @@ def test_kl(constants, point_estimates, mirror_samples): res0 = 0*res0 res1 = kl.gradient.to_global_data()[kk] assert_allclose(res0, res1) if simplify: res2 = klcst.gradient.to_global_data()[kk] assert_allclose(res0, res2) # Test number of samples expected_nsamps = 2*nsamps if mirror_samples else nsamps ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!