Commit 4231c07b authored by Reimar H Leike's avatar Reimar H Leike
Browse files

added a student-t energy, a Log(1+x) nonlinearity and a test for the energy

parent 156c9d79
Pipeline #62423 passed with stages
in 7 minutes and 5 seconds
...@@ -20,6 +20,7 @@ from .multi_field import MultiField ...@@ -20,6 +20,7 @@ from .multi_field import MultiField
from .operators.operator import Operator from .operators.operator import Operator
from .operators.adder import Adder from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
...@@ -51,7 +52,7 @@ from .operators.value_inserter import ValueInserter ...@@ -51,7 +52,7 @@ from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator, BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator) Squared2NormOperator, StudentTEnergy)
from .operators.convolution_operators import FuncConvolutionOperator from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \ from .probing import probe_with_posterior_samples, probe_diagonal, \
......
...@@ -27,6 +27,7 @@ from .linear_operator import LinearOperator ...@@ -27,6 +27,7 @@ from .linear_operator import LinearOperator
from .operator import Operator from .operator import Operator
from .sampling_enabler import SamplingEnabler from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator from .simple_linear_operators import VdotOperator
...@@ -64,7 +65,6 @@ class Squared2NormOperator(EnergyOperator): ...@@ -64,7 +65,6 @@ class Squared2NormOperator(EnergyOperator):
return x.new(val, jac) return x.new(val, jac)
return Field.scalar(x.vdot(x)) return Field.scalar(x.vdot(x))
class QuadraticFormOperator(EnergyOperator): class QuadraticFormOperator(EnergyOperator):
"""Computes the L2-norm of a Field or MultiField with respect to a """Computes the L2-norm of a Field or MultiField with respect to a
specific kernel given by `endo`. specific kernel given by `endo`.
...@@ -248,6 +248,43 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -248,6 +248,43 @@ class InverseGammaLikelihood(EnergyOperator):
return res.add_metric(metric) return res.add_metric(metric)
class StudentTEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
.. math ::
E(f) = -\\log \\text{Bernoulli}(d|f)
= -d^\\dagger \\log f - (1-d)^\\dagger \\log(1-f),
where f is a field defined on `d.domain` with the expected
frequencies of events.
Parameters
----------
d : Field
Data field with events (1) or non-events (0).
theta : Scalar
Degree of freedom parameter for the student t distribution
"""
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x):
self._check_input(x)
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
return v
met = ScalingOperator(self.domain, (self._theta+1)/(self._theta+3))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
class BernoulliEnergy(EnergyOperator): class BernoulliEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by """Computes likelihood energy of expected event frequency constrained by
event data. event data.
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
"""
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(x.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return x.new(res, jac@x.jac)
...@@ -46,6 +46,9 @@ def test_gaussian(field): ...@@ -46,6 +46,9 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain) energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field) ift.extra.check_jacobian_consistency(energy, field)
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_inverse_gamma(field): def test_inverse_gamma(field):
field = field.exp() field = field.exp()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment