Commit 668fc8c7 authored by Reimar Leike's avatar Reimar Leike
Browse files

added a Gaussian Energy with variabel sigma, has to be debugged

parent df116915
Pipeline #70181 failed with stages
in 19 minutes and 25 seconds
......@@ -44,7 +44,7 @@ from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator, StudentTEnergy)
Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \
......
......@@ -19,6 +19,7 @@ import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..field import Field
from ..multi_field import MultiField
from ..linearization import Linearization
......@@ -28,7 +29,7 @@ from .operator import Operator
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
class EnergyOperator(Operator):
......@@ -95,6 +96,63 @@ class QuadraticFormOperator(EnergyOperator):
return x.new(val, jac)
return Field.scalar(0.5*x.vdot(self._op(x)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian with unknown covariance.
Represents up to constants in :math:`m`:
.. math ::
E(f) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s),
an information energy for a Gaussian distribution with residual s and
covariance D.
Parameters
----------
residual : key
residual of the Gaussian.
inverse_covariance : key
Inverse covariance of the Gaussian.
domain : Domain, DomainTuple, tuple of Domain
Operator domain. By default it is inferred from `mean` or
`covariance` if specified
"""
def __init__(self, domain, residual, inverse_covariance):
self._residual = residual
self._icov = inverse_covariance
self._domain = MultiDomain.make({self._residual:domain,
self._icov:domain})
self._singledom = domain
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
xval = x.val if lin else x
res = .5*xval[self._residual].vdot(xval[self._residual]*xval[self._icov])\
- .5*xval[self._icov].log().sum()
if not lin:
return res
FA_res = FieldAdapter(self._singledom, self._residual)
FA_sig = FieldAdapter(self._singledom, self._icov)
jac_res = xval[self._residual]*xval[self._icov]
jac_res = VdotOperator(jac_res)(FA_res)
jac_sig = .5*(xval[self._residual].absolute()**2)
jac_sig = VdotOperator(jac_sig)(FA_sig)
jac_sig = jac_sig - VdotOperator(1./xval[self._residual])(FA_sig)
jac = (jac_sig + jac_res)(x.jac)
res = x.new(res, jac)
if not x.want_metric:
return res
mf = {self._residual:xval[self._icov],
self._icov:.5*xval[self._icov]**(-2)}
mf = MultiField.from_dict(mf)
metric = makeOp(mf)
metric = SandwichOperator(x.jac, metric)
return res.add_metric(metric)
class GaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian.
......
......@@ -42,6 +42,12 @@ def field(request):
s = S.draw_sample()
return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a':field, 'b':field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain,
residual='a', inverse_covariance='b')
ift.extra.check_jacobian_consistency(energy, mf)
def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment