Commit 21f0128d authored by Philipp Frank's avatar Philipp Frank Committed by Philipp Arras
Browse files

Implement Geometric KL

parent 0acc9ff7
......@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()` and for linear operators
MetricGaussianKL interface
Users do not instantiate `MetricGaussianKL` by its constructor anymore. Rather
`MetricGaussianKL.make()` shall be used. Additionally, `mirror_samples` is not
set by default anymore.
`mirror_samples` is not set by default anymore.
A new posterior approximation scheme, called geometric Variational Inference
(geoVI) was introduced. `GeoMetricKL` is analogous to `MetricGaussianKL` with
the exception that geoVI samples are used instead of MGVI samples. For further
details see (<>).
A new subclass of `EnergyOperator` was introduced and all `EnergyOperator`s
that are likelihoods are now `LikelihoodOperator`s. A `LikelihoodOperator`
has to implement the function `get_transformation`, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix.
Changes since NIFTy 5
......@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
......@@ -134,7 +134,7 @@ def main():
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -94,7 +94,7 @@ if __name__ == "__main__":
for i in range(5):
# Draw new samples and minimize KL
kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
kl = ift.MetricGaussianKL(mean, ham, n_samples, True)
kl, convergence = minimizer(kl)
mean = kl.position
......@@ -36,6 +36,8 @@ import nifty7 as ift
def main():
use_geo = False
name = 'GEO' if use_geo else 'MGVI'
dom = ift.UnstructuredDomain(1)
scale = 10
......@@ -91,12 +93,16 @@ def main():
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL.make(pos, ham, 40, False)
if use_geo:
mini_samp = ift.NewtonCG(
mgkl = ift.GeoMetricKL(pos, ham, 100, mini_samp, False)
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
vmax=np.max(z), cmap='gist_earth_r',
extent=x_limits_scaled + y_limits)
plt.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar()'pdf')
......@@ -105,12 +111,14 @@ def main():
samp = (samp + pos).val
plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
plt.scatter(np.array(xs)*scale, np.array(ys), label=name+' samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label=name+' latent mean')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
......@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import *
......@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out += f"{ndof[kk]:>11}"
out += "\n"
return out[:-1]
class _KeyModifier(LinearOperator):
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from .sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
......@@ -20,15 +20,18 @@ import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
from ..utilities import myassert
from .linear_operator import LinearOperator
from .operator import Operator
from .adder import Adder
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
def _check_sampling_dtype(domain, dtypes):
......@@ -78,6 +81,24 @@ class EnergyOperator(Operator):
_target = DomainTuple.scalar_domain()
class LikelihoodOperator(EnergyOperator):
"""`EnergyOperator` representing a likelihood. The input to the Operator
are the parameters of the likelihood. Unlike a general `EnergyOperator`,
the metric of a `LikelihoodOperator` is the Fisher information metric of
the likelihood.
def get_metric_at(self, x):
"""Computes the Fisher information metric for a `LikelihoodOperator`
at `x` using the Jacobian of the coordinate transformation given by
dtp, f = self.get_transformation()
ch = ScalingOperator(, 1.)
if dtp is not None:
ch = SamplingDtypeSetter(ch, dtp)
return SandwichOperator.make(f(Linearization.make_var(x)).jac, ch)
class Squared2NormOperator(EnergyOperator):
"""Computes the square of the L2-norm of the output of an operator.
......@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator):
return, VdotOperator(self._op(x.val)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
class VariableCovarianceGaussianEnergy(LikelihoodOperator):
"""Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal.
......@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
sampling_dtype : np.dtype
Data type of the samples. Usually either 'np.float*' or 'np.complex*'
use_full_fisher: boolean
Whether or not the proper Fisher information metric should be used as
a `metric`. If False the same approximation used in
`get_transformation` is used instead.
Default is True
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype,
use_full_fisher = True):
self._kr = str(residual_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
......@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = _iscomplex(sampling_dtype)
self._use_fisher = use_full_fisher
def apply(self, x):
......@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
return res
met = 1. if self._cplx else .5
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
if self._use_fisher:
met = 1. if self._cplx else 0.5
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
met = SamplingDtypeSetter(makeOp(met), self._dt)
met = self.get_metric_at(x.val)
return res.add_metric(met)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator
......@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = GaussianEnergy(inverse_covariance=makeOp(cst),
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
if not self._cplx:
trlog /= 2
res = res + ConstantEnergyOperator(-trlog)
res = res + ConstantEnergyOperator(0.)
myassert( is
return None, res
def get_transformation(self):
"""Note that for the metric of a `VariableCovarianceGaussianEnergy` no
global transformation to Euclidean space exists. A local approximation
ivoking the resudual is used instead.
r = FieldAdapter(self._domain[self._kr], self._kr)
ivar = FieldAdapter(self._domain[self._kr], self._ki)
sc = 1. if self._cplx else 0.5
return self._dt, r.adjoint@(ivar.ptw('sqrt')*r) + ivar.adjoint@(sc*ivar.ptw('log'))
class _SpecialGammaEnergy(EnergyOperator):
class _SpecialGammaEnergy(LikelihoodOperator):
def __init__(self, residual):
self._domain = DomainTuple.make(residual.domain)
self._resi = residual
self._cplx = _iscomplex(self._resi.dtype)
self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5)
self._dt = self._resi.dtype
def apply(self, x):
......@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator):
res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
if not x.want_metric:
return res
met = makeOp((self._scale(x.val))**(-2))
return res.add_metric(SamplingDtypeSetter(met, self._resi.dtype))
return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
sc = 1. if self._cplx else np.sqrt(0.5)
return self._dt, sc*ScalingOperator(self._domain, 1.).ptw('log')
class GaussianEnergy(EnergyOperator):
class GaussianEnergy(LikelihoodOperator):
"""Computes a negative-log Gaussian.
Represents up to constants in :math:`m`:
......@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator):
residual = x if self._mean is None else x - self._mean
res = self._op(residual).real
if x.want_metric:
return res.add_metric(self._met)
return res.add_metric(self.get_metric_at(x.val))
return res
def get_transformation(self):
icov, dtp = self._met, None
if isinstance(icov, SamplingDtypeSetter):
dtp = icov._dtype
icov = icov._op
return dtp, icov.get_sqrt()
def __repr__(self):
dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
return f'GaussianEnergy {dom}'
class PoissonianEnergy(EnergyOperator):
class PoissonianEnergy(LikelihoodOperator):
"""Computes likelihood Hamiltonians of expected count field constrained by
Poissonian count data.
......@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator):
res = x.sum() - x.ptw("log").vdot(self._d)
if not x.want_metric:
return res
return res.add_metric(SamplingDtypeSetter(makeOp(1./x.val), np.float64))
return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
return np.float64, 2.*ScalingOperator(self._domain,1.).sqrt()
class InverseGammaLikelihood(EnergyOperator):
class InverseGammaLikelihood(LikelihoodOperator):
"""Computes the negative log-likelihood of the inverse gamma distribution.
It negative log-pdf(x) is given by
......@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator):
res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
if not x.want_metric:
return res
met = makeOp(self._alphap1/(x.val**2))
if self._sampling_dtype is not None:
met = SamplingDtypeSetter(met, self._sampling_dtype)
return res.add_metric(met)
return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
fact = self._alphap1.ptw('sqrt')
res = makeOp(fact)@ScalingOperator(self._domain,1.).ptw('log')
return self._sampling_dtype, res
class StudentTEnergy(EnergyOperator):
class StudentTEnergy(LikelihoodOperator):
"""Computes likelihood energy corresponding to Student's t-distribution.
.. math ::
......@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator):
res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
if not x.want_metric:
return res
met = makeOp((self._theta+1) / (self._theta+3), self.domain)
return res.add_metric(SamplingDtypeSetter(met, np.float64))
return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
if isinstance(self._theta, Field) or isinstance(self._theta, MultiField):
th = self._theta
from ..extra import full
th = full(self._domain, self._theta)
return np.float64, makeOp(((th+1)/(th+3)).ptw('sqrt'))
class BernoulliEnergy(EnergyOperator):
class BernoulliEnergy(LikelihoodOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
......@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator):
res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
if not x.want_metric:
return res
met = makeOp(1./(x.val*(1. - x.val)))
return res.add_metric(SamplingDtypeSetter(met, np.float64))
return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
from ..extra import full
res = Adder(full(self._domain,1.))@ScalingOperator(self._domain,-1)
res = res * ScalingOperator(self._domain,1).ptw('reciprocal')
return np.float64, -2.*res.ptw('sqrt').ptw('arctan')
class StandardHamiltonian(EnergyOperator):
"""Computes an information Hamiltonian in its standard form, i.e. with the
......@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator):
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples)
def get_transformation(self):
dtp, trafo = self._h.get_transformation()
mymap = map(lambda v: trafo@Adder(v), self._res_samples)
return dtp, utilities.my_sum(mymap)/np.sqrt(len(self._res_samples))
......@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta):
return None
def get_transformation(self):
"""The coordinate transformation that maps into a coordinate system
where the metric of a likelihood is the Euclidean metric.
This is `None`, except when the object is considered a likelihood i.E.
for an instance of `EnergyOperator` with its metric being a proper
Fisher information metric, or a sum or nested sum thereof.
np.dtype, or dict of np.dtype : The dtype(s) of the target space of the
Operator : The transformation that maps from `domain` into the
Euclidean target space.
return None
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
......@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def get_transformation(self):
tr = self._ops[0].get_transformation()
if tr is None:
return tr
return tr[0], _OpChain.make((tr[1],)+self._ops[1:])
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
......@@ -488,6 +511,25 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res
def get_transformation(self):
tr1 = self._op1.get_transformation()
tr2 = self._op2.get_transformation()
if tr1 is None or tr2 is None:
return None
from ..extra import _KeyModifier
dtype, trafo = {}, None
for i, lh in enumerate([self._op1, self._op2]):
dtp, tr = lh.get_transformation()
if isinstance(, MultiDomain):
dtype.update({str(i)+d:dtp[d] for d in dtp.keys()})
tr = _KeyModifier(, str(i)) @ tr
trafo = tr if trafo is None else trafo+tr
dtype[str(i)] = dtp
tr = tr.ducktape_left(str(i))
trafo = tr if trafo is None else trafo + tr
return dtype, trafo
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
......@@ -73,6 +73,9 @@ class SandwichOperator(EndomorphicOperator):
# If our sandwich is diagonal, we can return immediately
if isinstance(op, (ScalingOperator, DiagonalOperator)):
if isinstance(cheese, SamplingDtypeSetter):
return SamplingDtypeSetter(op, cheese._dtype)
return op
return SandwichOperator(bun, cheese, op, _callingfrommake=True)
......@@ -39,7 +39,7 @@ from .plot import Plot
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'density_estimator', 'create_harmonic_smoothing_operator',
'from_random', 'full', 'makeField', 'is_fieldlike',
'is_linearization', 'is_operator', 'makeDomain',
'is_linearization', 'is_operator', 'makeDomain', 'is_likelihood',
'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain', 'single_plot', 'exec_time',
'calculate_position'] + list(pointwise.ptw_dict.keys())
......@@ -221,9 +221,9 @@ def density_estimator(domain, pad=1.0, cf_fluctuations=None,
from .library.correlated_fields import CorrelatedFieldMaker
from .library.special_distributions import UniformOperator
cf_azm_uniform_sane_default = (1e-15, 5.0)
cf_azm_uniform_sane_default = (1e-4, 1.0)
cf_fluctuations_sane_default = {
"scale": (0.3, 0.2),
"scale": (0.5, 0.3),
"cutoff": (4.0, 3.0),
"loglogslope": (-6.0, 3.0)
......@@ -557,7 +557,7 @@ def calculate_position(operator, output):
"""Finds approximate preimage of an operator for a given output."""
from .minimization.descent_minimizers import NewtonCG
from .minimization.iteration_controllers import GradientNormController
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .minimization.kl_energies import MetricGaussianKL
from .operators.scaling_operator import ScalingOperator
from .operators.energy_operators import GaussianEnergy, StandardHamiltonian
if not isinstance(operator, Operator):
......@@ -585,18 +585,24 @@ def calculate_position(operator, output):
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):'Start iteration {ii+1}/3')
kl = MetricGaussianKL.make(pos, H, 3, True)
kl = MetricGaussianKL(pos, H, 3, True)
kl, _ = minimizer(kl)
pos = kl.position
return pos
def is_likelihood(obj):
"""Checks if object is likelihood-like.
return isinstance(obj, Operator) and obj.get_transformation() is not None
def is_operator(obj):
"""Check if object is operator-like.
"""Checks if object is operator-like.
A simple `isinstance(obj, ift.Operator)` does not give the expected
A simple `isinstance(obj, ift.Operator)` does give the expected
result because, e.g., :class:`~nifty7.field.Field` inherits from
......@@ -604,20 +610,19 @@ def is_operator(obj):
def is_linearization(obj):
"""Check if object is linearization-like."""
"""Checks if object is linearization-like."""
return isinstance(obj, Operator) and obj.jac is not None
def is_fieldlike(obj):
"""Check if object is field-like.
"""Checks if object is field-like.
A simple `isinstance(obj, ift.Field)` does not give the expected
A simple `isinstance(obj, ift.Field)` does give the expected
result because users might have implemented another class which
behaves field-like but is not an instance of
:class:`~nifty7.field.Field`. Instances of
:class:`~nifty7.linearization.Linearization` are considered to be
:class:`~nifty7.field.Field`. Also not that instances of
:class:`~nifty7.linearization.Linearization` behave field-like.
return isinstance(obj, Operator) and obj.val is not None
......@@ -31,7 +31,8 @@ pmp = pytest.mark.parametrize
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
@pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf):
@pmp('geo', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf, geo):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
dom = ift.RGSpace((12,), (2.12))
......@@ -51,11 +52,19 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
'n_samples': nsamps,
'mean': mean0,
'hamiltonian': h}
if geo:
args['minimizer_samp'] = ift.NewtonCG(ic)
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
if geo:
kl = ift.MetricGaussianKL.make(**args)
if geo:
kl = ift.GeoMetricKL(**args)
kl = ift.MetricGaussianKL(**args)
myassert(len(ic.history) > 0)
myassert(len(ic.history) == len(ic.history.time_stamps))
myassert(len(ic.history) == len(ic.history.energy_values))
......@@ -71,8 +80,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
tmph = h
tmpmean = mean0
klpure = ift.MetricGaussianKL(tmpmean, tmph, nsamps, mirror_samples, None, locsamp, False, True)
if geo and mirror_samples:
klpure = ift.minimization.kl_energies._SampledKLEnergy(tmpmean, tmph, 2*nsamps, False, None, locsamp, False)
klpure = ift.minimization.kl_energies._SampledKLEnergy(tmpmean, tmph, nsamps, mirror_samples, None, locsamp, False)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
myassert(len(tuple(kl.samples)) == expected_nsamps)