Commit 6dfed616 authored by Reimar H Leike's avatar Reimar H Leike
Browse files

Merge branch 'gig-energy-pa' into 'gig-energy'


See merge request !410
parents 2206438b b4282bfe
Pipeline #70292 passed with stages
in 15 minutes and 36 seconds
......@@ -96,68 +96,44 @@ class QuadraticFormOperator(EnergyOperator):
return, jac)
return Field.scalar(0.5*x.vdot(self._op(x)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian with unknown covariance.
Represents up to constants in :math:`s`:
.. math ::
E(f) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s),
E(f) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),
an information energy for a Gaussian distribution with residual s and
covariance D.
diagonal covariance D.
domain : Domain, DomainTuple, tuple of Domain
Operator domain. By default it is inferred from `s` or
`covariance` if specified
Operator domain.
residual : key
residual of the Gaussian.
inverse_covariance : key
Inverse covariance of the Gaussian.
Residual key of the Gaussian.
inverse_covariance : key
Inverse covariance diagonal key of the Gaussian.
def __init__(self, domain, residual, inverse_covariance):
self._residual = residual
self._icov = inverse_covariance
self._domain = MultiDomain.make({self._residual:domain,
self._singledom = domain
def __init__(self, domain, residual_key, inverse_covariance_key):
self._r = str(residual_key)
self._icov = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
res = .5*xval[self._residual].vdot(xval[self._residual]*xval[self._icov])\
- .5*xval[self._icov].log().sum()
if not lin:
return res
FA_res = FieldAdapter(self._singledom, self._residual)
FA_sig = FieldAdapter(self._singledom, self._icov)
jac_res = xval[self._residual]*xval[self._icov]
jac_res = VdotOperator(jac_res)(FA_res)
# So here we are varying w.r.t. inverse covariance
jac_sig = .5*(xval[self._residual].absolute()**2)
jac_sig = VdotOperator(jac_sig)(FA_sig)
jac_sig = jac_sig - .5*VdotOperator(1./xval[self._icov])(FA_sig)
jac = (jac_sig + jac_res)(x.jac)
res =, jac)
if not x.want_metric:
return res
mf = {self._residual:xval[self._icov],
mf = MultiField.from_dict(mf)
metric = makeOp(mf)
metric = SandwichOperator(x.jac, metric)
return res.add_metric(metric)
res0 = x[self._r].vdot(x[self._r]*x[self._icov]).real
res1 = x[self._icov].log().sum()
res = 0.5*(res0-res1)
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf))
return res.add_metric(SandwichOperator.make(x.jac, metric))
class GaussianEnergy(EnergyOperator):
......@@ -42,30 +42,34 @@ def field(request):
s = S.draw_sample()
return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a':field, 'b': field.exp()}
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain,
residual='a', inverse_covariance='b')
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True)
lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12)
np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12)
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment