Commit d37dc659 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'gig-energy' into 'NIFTy_6'

Gig energy

See merge request !411
parents df116915 7a6052d7
Pipeline #70299 passed with stages
in 15 minutes and 32 seconds
...@@ -44,7 +44,7 @@ from .operators.value_inserter import ValueInserter ...@@ -44,7 +44,7 @@ from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator, BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator, StudentTEnergy) Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
from .operators.convolution_operators import FuncConvolutionOperator from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \ from .probing import probe_with_posterior_samples, probe_diagonal, \
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
from .. import utilities from .. import utilities
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..field import Field from ..field import Field
from ..multi_field import MultiField from ..multi_field import MultiField
from ..linearization import Linearization from ..linearization import Linearization
...@@ -28,7 +29,7 @@ from .operator import Operator ...@@ -28,7 +29,7 @@ from .operator import Operator
from .sampling_enabler import SamplingEnabler from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator from .simple_linear_operators import VdotOperator, FieldAdapter
class EnergyOperator(Operator): class EnergyOperator(Operator):
...@@ -96,6 +97,47 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -96,6 +97,47 @@ class QuadraticFormOperator(EnergyOperator):
return Field.scalar(0.5*x.vdot(self._op(x))) return Field.scalar(0.5*x.vdot(self._op(x)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
"""Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal.
.. math ::
E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),
an information energy for a Gaussian distribution with residual s and
diagonal covariance D.
The domain of this energy will be a MultiDomain with two keys,
the target will be the scalar domain.
Parameters
----------
domain : Domain, DomainTuple, tuple of Domain
domain of the residual and domain of the covariance diagonal.
residual : key
Residual key of the Gaussian.
inverse_covariance : key
Inverse covariance diagonal key of the Gaussian.
"""
def __init__(self, domain, residual_key, inverse_covariance_key):
self._r = str(residual_key)
self._icov = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
def apply(self, x):
self._check_input(x)
res0 = x[self._r].vdot(x[self._r]*x[self._icov]).real
res1 = x[self._icov].log().sum()
res = 0.5*(res0-res1)
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf))
return res.add_metric(SandwichOperator.make(x.jac, metric))
class GaussianEnergy(EnergyOperator): class GaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian. """Computes a negative-log Gaussian.
......
...@@ -43,23 +43,33 @@ def field(request): ...@@ -43,23 +43,33 @@ def field(request):
return ift.MultiField.from_dict({'s1': s})['s1'] return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_gaussian(field): def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain) energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field) ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.), @pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))]) lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov): def test_ScaledEnergy(field, icov):
icov = icov(field.domain) icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov) energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field) ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True) lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12) np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12)
met2.draw_sample() met2.draw_sample()
def test_studentt(field): def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5) energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment