Commit acb3d258 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'fix_mpi_kl' into 'NIFTy_5'

Fix mpi kl

See merge request !366
parents 3490e9ce 0dfcda3b
Pipeline #63108 passed with stages
in 7 minutes and 39 seconds
......@@ -20,6 +20,7 @@ from .multi_field import MultiField
from .operators.operator import Operator
from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......@@ -51,7 +52,7 @@ from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator)
Squared2NormOperator, StudentTEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .probing import probe_with_posterior_samples, probe_diagonal, \
......
......@@ -159,6 +159,7 @@ class MetricGaussianKL_MPI(Energy):
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = []
rand_state = np.random.get_state()
for i in range(lo, hi):
if mirror_samples:
np.random.seed(i//2+seed_offset)
......@@ -169,8 +170,9 @@ class MetricGaussianKL_MPI(Energy):
_samples.append(((i % 2)*2-1) *
met.draw_sample(from_inverse=True))
else:
np.random.seed(i)
np.random.seed(i+seed_offset)
_samples.append(met.draw_sample(from_inverse=True))
np.random.set_state(rand_state)
_samples = tuple(_samples)
if mirror_samples:
n_samples *= 2
......@@ -240,8 +242,11 @@ class MetricGaussianKL_MPI(Energy):
raise NotImplementedError()
lin = self._lin.with_want_metric()
samp = full(self._hamiltonian.domain, 0.)
rand_state = np.random.get_state()
np.random.seed(rank+np.random.randint(99999))
for v in self._samples:
samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
np.random.set_state(rand_state)
return allreduce_sum_field(samp)
def metric_sample(self, from_inverse=False, dtype=np.float64):
......
......@@ -27,6 +27,7 @@ from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
......@@ -248,6 +249,43 @@ class InverseGammaLikelihood(EnergyOperator):
return res.add_metric(metric)
class StudentTEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
.. math ::
E(f) = -\\log \\text{Bernoulli}(d|f)
= -d^\\dagger \\log f - (1-d)^\\dagger \\log(1-f),
where f is a field defined on `d.domain` with the expected
frequencies of events.
Parameters
----------
d : Field
Data field with events (1) or non-events (0).
theta : Scalar
Degree of freedom parameter for the student t distribution
"""
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x):
self._check_input(x)
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
return v
met = ScalingOperator((self._theta+1)/(self._theta+3), self.domain)
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
class BernoulliEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
"""
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(xval.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return x.new(res, jac@x.jac)
......@@ -47,6 +47,11 @@ def test_gaussian(field):
ift.extra.check_jacobian_consistency(energy, field)
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_inverse_gamma(field):
field = field.exp()
space = field.domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment