Commit 00c9005f authored by Philipp Arras's avatar Philipp Arras
Browse files

Add nontrivial simplify for constant input

parent 0b26ae98
Pipeline #76113 passed with stages
in 13 minutes and 43 seconds
......@@ -175,6 +175,47 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator
assert len(c_inp.keys()) == 1
key = c_inp.keys()[0]
assert key in self._domain.keys()
cst = c_inp[key]
if key == self._kr:
res = _SpecialGammaEnergy(cst).ducktape(self._ki)
else:
dt = self._dt[self._kr]
res = GaussianEnergy(inverse_covariance=makeOp(cst),
sampling_dtype=dt).ducktape(self._kr)
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
trlog /= 2
res = res + ConstantEnergyOperator(res.domain, -trlog)
res = res + ConstantEnergyOperator(self._domain, 0.)
assert res.domain is self.domain
assert res.target is self.target
return None, res
class _SpecialGammaEnergy(EnergyOperator):
def __init__(self, residual):
self._domain = DomainTuple.make(residual.domain)
self._resi = residual
self._cplx = _iscomplex(self._resi.dtype)
self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5)
def apply(self, x):
self._check_input(x)
r = self._resi
if self._cplx:
res = 0.5*(r*x.real).vdot(r).real - x.ptw("log").sum()
else:
res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
if not x.want_metric:
return res
met = makeOp((self._scale(x.val))**(-2))
return res.add_metric(SamplingDtypeSetter(met, self._resi.dtype))
class GaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian.
......
......@@ -347,6 +347,12 @@ class NullOperator(LinearOperator):
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- NullOperator <- {dom}'
def draw_sample(self, from_inverse=False):
if self._domain is not self._target:
raise RuntimeError
from ..sugar import full
return full(self._domain, 0.)
class PartialExtractor(LinearOperator):
def __init__(self, domain, target):
......
......@@ -109,7 +109,10 @@ class SlowPartialConstantOperator(Operator):
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
from ..field import Field
self._domain = makeDomain(dom)
if not isinstance(output, Field):
output = Field.scalar(float(output))
if self.target is not output.domain:
raise TypeError
self._output = output
......
......@@ -88,6 +88,15 @@ def test_variablecovariancegaussian(field):
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_specialgamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
energy = ift.operators.energy_operators._SpecialGammaEnergy(field)
loc = ift.from_random(energy.domain).exp()
ift.extra.check_jacobian_consistency(energy, loc, tol=1e-6, ntries=ntries)
energy(ift.Linearization.make_var(loc, want_metric=True)).metric.draw_sample()
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
......
......@@ -28,6 +28,7 @@ def _flat_PS(k):
pmp = pytest.mark.parametrize
ntries = 10
@pmp('space', [ift.GLSpace(5),
......@@ -70,4 +71,55 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
energy = ift.GaussianEnergy(d, N) @ d_model()
ift.extra.check_jacobian_consistency(
energy, xi0, ntries=10, tol=1e-6)
energy, xi0, ntries=ntries, tol=1e-6)
@pmp('cplx', [True, False])
def testgaussianenergy_compatibility(cplx):
dt = np.complex128 if cplx else np.float64
dom = ift.UnstructuredDomain(3)
e = ift.VariableCovarianceGaussianEnergy(dom, 'resi', 'icov', dt)
resi = ift.from_random(dom)
if cplx:
resi = resi + 1j*ift.from_random(dom)
loc0 = ift.MultiField.from_dict({'resi': resi})
loc1 = ift.MultiField.from_dict({'icov': ift.from_random(dom).exp()})
loc = loc0.unite(loc1)
val0 = e(loc).val
_, e0 = e.simplify_for_constant_input(loc0)
val1 = e0(loc).val
val2 = e0(loc.unite(loc0)).val
np.testing.assert_equal(val1, val2)
np.testing.assert_equal(val0, val1)
_, e1 = e.simplify_for_constant_input(loc1)
val1 = e1(loc).val
val2 = e1(loc.unite(loc1)).val
np.testing.assert_equal(val0, val1)
np.testing.assert_equal(val1, val2)
ift.extra.check_jacobian_consistency(e, loc, ntries=ntries)
ift.extra.check_jacobian_consistency(e0, loc, ntries=ntries)
ift.extra.check_jacobian_consistency(e1, loc, ntries=ntries)
# Test jacobian is zero
lin = ift.Linearization.make_var(loc, want_metric=True)
grad = e(lin).gradient.val
grad0 = e0(lin).gradient.val
grad1 = e1(lin).gradient.val
samp = e(lin).metric.draw_sample().val
samp0 = e0(lin).metric.draw_sample().val
samp1 = e1(lin).metric.draw_sample().val
np.testing.assert_equal(samp0['resi'], 0.)
np.testing.assert_equal(samp1['icov'], 0.)
np.testing.assert_equal(grad0['resi'], 0.)
np.testing.assert_equal(grad1['icov'], 0.)
np.testing.assert_(all(samp['resi'] != 0))
np.testing.assert_(all(samp['icov'] != 0))
np.testing.assert_(all(samp0['icov'] != 0))
np.testing.assert_(all(samp1['resi'] != 0))
np.testing.assert_(all(grad['resi'] != 0))
np.testing.assert_(all(grad['icov'] != 0))
np.testing.assert_(all(grad0['icov'] != 0))
np.testing.assert_(all(grad1['resi'] != 0))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment