Commit 4231c07b authored by Reimar H Leike's avatar Reimar H Leike
Browse files

added a student-t energy, a Log(1+x) nonlinearity and a test for the energy

parent 156c9d79
......@@ -20,6 +20,7 @@ from .multi_field import MultiField
from .operators.operator import Operator
from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......@@ -51,7 +52,7 @@ from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator, StudentTEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \
......@@ -27,6 +27,7 @@ from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
......@@ -64,7 +65,6 @@ class Squared2NormOperator(EnergyOperator):
return, jac)
return Field.scalar(x.vdot(x))
class QuadraticFormOperator(EnergyOperator):
"""Computes the L2-norm of a Field or MultiField with respect to a
specific kernel given by `endo`.
......@@ -248,6 +248,43 @@ class InverseGammaLikelihood(EnergyOperator):
return res.add_metric(metric)
class StudentTEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
.. math ::
E(f) = -\\log \\text{Bernoulli}(d|f)
= -d^\\dagger \\log f - (1-d)^\\dagger \\log(1-f),
where f is a field defined on `d.domain` with the expected
frequencies of events.
d : Field
Data field with events (1) or non-events (0).
theta : Scalar
Degree of freedom parameter for the student t distribution
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x):
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
return v
met = ScalingOperator(self.domain, (self._theta+1)/(self._theta+3))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
class BernoulliEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2019 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(x.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return, jac@x.jac)
......@@ -46,6 +46,9 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field)
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_inverse_gamma(field):
field = field.exp()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment