Commit 21f0128d authored by Philipp Frank's avatar Philipp Frank Committed by Philipp Arras
Browse files

Implement Geometric KL

parent 0acc9ff7
...@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in ...@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()` and for linear operators `ift.extra.check_operator()` and for linear operators
`ift.extra.check_linear_operator()`. `ift.extra.check_linear_operator()`.
MetricGaussianKL interface MetricGaussianKL interface
-------------------------- --------------------------
Users do not instantiate `MetricGaussianKL` by its constructor anymore. Rather `mirror_samples` is not set by default anymore.
`MetricGaussianKL.make()` shall be used. Additionally, `mirror_samples` is not
set by default anymore.
GeoMetricKL
-----------
A new posterior approximation scheme, called geometric Variational Inference
(geoVI) was introduced. `GeoMetricKL` is analogous to `MetricGaussianKL` with
the exception that geoVI samples are used instead of MGVI samples. For further
details see (<https://arxiv.org/abs/2105.10470>).
LikelihoodOperator
------------------
A new subclass of `EnergyOperator` was introduced and all `EnergyOperator`s
that are likelihoods are now `LikelihoodOperator`s. A `LikelihoodOperator`
has to implement the function `get_transformation`, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix.
Changes since NIFTy 5 Changes since NIFTy 5
......
...@@ -123,7 +123,7 @@ def main(): ...@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times # Draw new samples to approximate the KL five times
for i in range(5): for i in range(5):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True) KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
mean = KL.position mean = KL.position
ift.extra.minisanity(data, lambda x: N.inverse, signal_response, ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
......
...@@ -134,7 +134,7 @@ def main(): ...@@ -134,7 +134,7 @@ def main():
for i in range(5): for i in range(5):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True) KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
mean = KL.position mean = KL.position
......
...@@ -94,7 +94,7 @@ if __name__ == "__main__": ...@@ -94,7 +94,7 @@ if __name__ == "__main__":
for i in range(5): for i in range(5):
# Draw new samples and minimize KL # Draw new samples and minimize KL
kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True) kl = ift.MetricGaussianKL(mean, ham, n_samples, True)
kl, convergence = minimizer(kl) kl, convergence = minimizer(kl)
mean = kl.position mean = kl.position
......
...@@ -36,6 +36,8 @@ import nifty7 as ift ...@@ -36,6 +36,8 @@ import nifty7 as ift
def main(): def main():
use_geo = False
name = 'GEO' if use_geo else 'MGVI'
dom = ift.UnstructuredDomain(1) dom = ift.UnstructuredDomain(1)
scale = 10 scale = 10
...@@ -91,12 +93,16 @@ def main(): ...@@ -91,12 +93,16 @@ def main():
plt.figure(figsize=[12, 8]) plt.figure(figsize=[12, 8])
for ii in range(15): for ii in range(15):
if ii % 3 == 0: if ii % 3 == 0:
mgkl = ift.MetricGaussianKL.make(pos, ham, 40, False) if use_geo:
mini_samp = ift.NewtonCG(
ift.GradientNormController(iteration_limit=5))
mgkl = ift.GeoMetricKL(pos, ham, 100, mini_samp, False)
else:
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
plt.cla() plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3, plt.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
vmax=np.max(z), cmap='gist_earth_r', cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
extent=x_limits_scaled + y_limits)
if ii == 0: if ii == 0:
cbar = plt.colorbar() cbar = plt.colorbar()
cbar.ax.set_ylabel('pdf') cbar.ax.set_ylabel('pdf')
...@@ -105,12 +111,14 @@ def main(): ...@@ -105,12 +111,14 @@ def main():
samp = (samp + pos).val samp = (samp + pos).val
xs.append(samp['a']) xs.append(samp['a'])
ys.append(samp['b']) ys.append(samp['b'])
plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
plt.scatter(np.array(map_xs)*scale, np.array(map_ys), plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples') label='Laplace samples')
plt.scatter(np.array(xs)*scale, np.array(ys), label=name+' samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label=name+' latent mean')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'], plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution') label='Maximum a posterior solution')
plt.xlim(x_limits_scaled)
plt.ylim(y_limits)
plt.legend() plt.legend()
plt.draw() plt.draw()
plt.pause(1.0) plt.pause(1.0)
......
...@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B ...@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter from .minimization.energy_adapter import EnergyAdapter
from .minimization.metric_gaussian_kl import MetricGaussianKL from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import * from .sugar import *
......
...@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen): ...@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out += f"{ndof[kk]:>11}" out += f"{ndof[kk]:>11}"
out += "\n" out += "\n"
return out[:-1] return out[:-1]
class _KeyModifier(LinearOperator):
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from .sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
...@@ -20,15 +20,18 @@ import numpy as np ...@@ -20,15 +20,18 @@ import numpy as np
from .. import utilities from .. import utilities
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
from ..multi_field import MultiField from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp from ..sugar import makeDomain, makeOp
from ..utilities import myassert from ..utilities import myassert
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .operator import Operator from .operator import Operator
from .adder import Adder
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
def _check_sampling_dtype(domain, dtypes): def _check_sampling_dtype(domain, dtypes):
...@@ -78,6 +81,24 @@ class EnergyOperator(Operator): ...@@ -78,6 +81,24 @@ class EnergyOperator(Operator):
_target = DomainTuple.scalar_domain() _target = DomainTuple.scalar_domain()
class LikelihoodOperator(EnergyOperator):
"""`EnergyOperator` representing a likelihood. The input to the Operator
are the parameters of the likelihood. Unlike a general `EnergyOperator`,
the metric of a `LikelihoodOperator` is the Fisher information metric of
the likelihood.
"""
def get_metric_at(self, x):
"""Computes the Fisher information metric for a `LikelihoodOperator`
at `x` using the Jacobian of the coordinate transformation given by
`get_transformation`.
"""
dtp, f = self.get_transformation()
ch = ScalingOperator(f.target, 1.)
if dtp is not None:
ch = SamplingDtypeSetter(ch, dtp)
return SandwichOperator.make(f(Linearization.make_var(x)).jac, ch)
class Squared2NormOperator(EnergyOperator): class Squared2NormOperator(EnergyOperator):
"""Computes the square of the L2-norm of the output of an operator. """Computes the square of the L2-norm of the output of an operator.
...@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator):
return x.new(res, VdotOperator(self._op(x.val))) return x.new(res, VdotOperator(self._op(x.val)))
class VariableCovarianceGaussianEnergy(EnergyOperator): class VariableCovarianceGaussianEnergy(LikelihoodOperator):
"""Computes the negative log pdf of a Gaussian with unknown covariance. """Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal. The covariance is assumed to be diagonal.
...@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
sampling_dtype : np.dtype sampling_dtype : np.dtype
Data type of the samples. Usually either 'np.float*' or 'np.complex*' Data type of the samples. Usually either 'np.float*' or 'np.complex*'
use_full_fisher: boolean
Whether or not the proper Fisher information metric should be used as
a `metric`. If False the same approximation used in
`get_transformation` is used instead.
Default is True
""" """
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype): def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype,
use_full_fisher = True):
self._kr = str(residual_key) self._kr = str(residual_key)
self._ki = str(inverse_covariance_key) self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain) dom = DomainTuple.make(domain)
...@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._dt = {self._kr: sampling_dtype, self._ki: np.float64} self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt) _check_sampling_dtype(self._domain, self._dt)
self._cplx = _iscomplex(sampling_dtype) self._cplx = _iscomplex(sampling_dtype)
self._use_fisher = use_full_fisher
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
...@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum()) res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric: if not x.want_metric:
return res return res
met = 1. if self._cplx else .5 if self._use_fisher:
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)}, met = 1. if self._cplx else 0.5
domain=self._domain) met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt)) domain=self._domain)
met = SamplingDtypeSetter(makeOp(met), self._dt)
else:
met = self.get_metric_at(x.val)
return res.add_metric(met)
def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator from .simplify_for_const import ConstantEnergyOperator
...@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = GaussianEnergy(inverse_covariance=makeOp(cst), res = GaussianEnergy(inverse_covariance=makeOp(cst),
sampling_dtype=dt).ducktape(self._kr) sampling_dtype=dt).ducktape(self._kr)
trlog = cst.log().sum().val_rw() trlog = cst.log().sum().val_rw()
if not _iscomplex(dt): if not self._cplx:
trlog /= 2 trlog /= 2
res = res + ConstantEnergyOperator(-trlog) res = res + ConstantEnergyOperator(-trlog)
res = res + ConstantEnergyOperator(0.) res = res + ConstantEnergyOperator(0.)
myassert(res.target is self.target) myassert(res.target is self.target)
return None, res return None, res
def get_transformation(self):
"""Note that for the metric of a `VariableCovarianceGaussianEnergy` no
global transformation to Euclidean space exists. A local approximation
ivoking the resudual is used instead.
"""
r = FieldAdapter(self._domain[self._kr], self._kr)
ivar = FieldAdapter(self._domain[self._kr], self._ki)
sc = 1. if self._cplx else 0.5
return self._dt, r.adjoint@(ivar.ptw('sqrt')*r) + ivar.adjoint@(sc*ivar.ptw('log'))
class _SpecialGammaEnergy(EnergyOperator):
class _SpecialGammaEnergy(LikelihoodOperator):
def __init__(self, residual): def __init__(self, residual):
self._domain = DomainTuple.make(residual.domain) self._domain = DomainTuple.make(residual.domain)
self._resi = residual self._resi = residual
self._cplx = _iscomplex(self._resi.dtype) self._cplx = _iscomplex(self._resi.dtype)
self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5) self._dt = self._resi.dtype
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
...@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator): ...@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator):
res = 0.5*((r*x).vdot(r) - x.ptw("log").sum()) res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
if not x.want_metric: if not x.want_metric:
return res return res
met = makeOp((self._scale(x.val))**(-2)) return res.add_metric(self.get_metric_at(x.val))
return res.add_metric(SamplingDtypeSetter(met, self._resi.dtype))
def get_transformation(self):
sc = 1. if self._cplx else np.sqrt(0.5)
return self._dt, sc*ScalingOperator(self._domain, 1.).ptw('log')
class GaussianEnergy(EnergyOperator): class GaussianEnergy(LikelihoodOperator):
"""Computes a negative-log Gaussian. """Computes a negative-log Gaussian.
Represents up to constants in :math:`m`: Represents up to constants in :math:`m`:
...@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator): ...@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator):
residual = x if self._mean is None else x - self._mean residual = x if self._mean is None else x - self._mean
res = self._op(residual).real res = self._op(residual).real
if x.want_metric: if x.want_metric:
return res.add_metric(self._met) return res.add_metric(self.get_metric_at(x.val))
return res return res
def get_transformation(self):
icov, dtp = self._met, None
if isinstance(icov, SamplingDtypeSetter):
dtp = icov._dtype
icov = icov._op
return dtp, icov.get_sqrt()
def __repr__(self): def __repr__(self):
dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys() dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
return f'GaussianEnergy {dom}' return f'GaussianEnergy {dom}'
class PoissonianEnergy(EnergyOperator): class PoissonianEnergy(LikelihoodOperator):
"""Computes likelihood Hamiltonians of expected count field constrained by """Computes likelihood Hamiltonians of expected count field constrained by
Poissonian count data. Poissonian count data.
...@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator): ...@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator):
res = x.sum() - x.ptw("log").vdot(self._d) res = x.sum() - x.ptw("log").vdot(self._d)
if not x.want_metric: if not x.want_metric:
return res return res
return res.add_metric(SamplingDtypeSetter(makeOp(1./x.val), np.float64)) return res.add_metric(self.get_metric_at(x.val))
def get_transformation(self):
return np.float64, 2.*ScalingOperator(self._domain,1.).sqrt()
class InverseGammaLikelihood(EnergyOperator): class InverseGammaLikelihood(LikelihoodOperator):
"""Computes the negative log-likelihood of the inverse gamma distribution. """Computes the negative log-likelihood of the inverse gamma distribution.
It negative log-pdf(x) is given by It negative log-pdf(x) is given by
...@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator):
res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta) res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
if not x.want_metric: if not x.want_metric:
return res return res
met = makeOp(self._alphap1/(x.val**2)) return res.add_metric(self.get_metric_at(x.val))
if self._sampling_dtype is not None:
met = SamplingDtypeSetter(met, self._sampling_dtype)
return res.add_metric(met)
def get_transformation(self):
fact = self._alphap1.ptw('sqrt')
res = makeOp(fact)@ScalingOperator(self._domain,1.).ptw('log')
return self._sampling_dtype, res
class StudentTEnergy(EnergyOperator): class StudentTEnergy(LikelihoodOperator):
"""Computes likelihood energy corresponding to Student's t-distribution. """Computes likelihood energy corresponding to Student's t-distribution.
.. math :: .. math ::
...@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator): ...@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator):
res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum() res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
if not x.want_metric: if not x.want_metric:
return res return res
met = makeOp((self._theta+1) / (self._theta+3), self.domain) return res.add_metric(self.get_metric_at(x.val))
return res.add_metric(SamplingDtypeSetter(met, np.float64))
def get_transformation(self):
if isinstance(self._theta, Field) or isinstance(self._theta, MultiField):
th = self._theta
else:
from ..extra import full
th = full(self._domain, self._theta)
return np.float64, makeOp(((th+1)/(th+3)).ptw('sqrt'))
class BernoulliEnergy(EnergyOperator): class BernoulliEnergy(LikelihoodOperator):
"""Computes likelihood energy of expected event frequency constrained by """Computes likelihood energy of expected event frequency constrained by
event data. event data.
...@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator): ...@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator):
res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.) res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
if not x.want_metric: if not x.want_metric:
return res return res
met = makeOp(1./(x.val*(1. - x.val))) return res.add_metric(self.get_metric_at(x.val))
return res.add_metric(SamplingDtypeSetter(met, np.float64))
def get_transformation(self):
from ..extra import full
res = Adder(full(self._domain,1.))@ScalingOperator(self._domain,-1)
res = res * ScalingOperator(self._domain,1).ptw('reciprocal')
return np.float64, -2.*res.ptw('sqrt').ptw('arctan')
class StandardHamiltonian(EnergyOperator): class StandardHamiltonian(EnergyOperator):
"""Computes an information Hamiltonian in its standard form, i.e. with the """Computes an information Hamiltonian in its standard form, i.e. with the
...@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator): ...@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator):
self._check_input(x) self._check_input(x)
mymap = map(lambda v: self._h(x+v), self._res_samples) mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples) return utilities.my_sum(mymap)/len(self._res_samples)
def get_transformation(self):
dtp, trafo = self._h.get_transformation()
mymap = map(lambda v: trafo@Adder(v), self._res_samples)
return dtp, utilities.my_sum(mymap)/np.sqrt(len(self._res_samples))
...@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta): ...@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta):
""" """
return None return None
def get_transformation(self):
"""The coordinate transformation that maps into a coordinate system
where the metric of a likelihood is the Euclidean metric.
This is `None`, except when the object is considered a likelihood i.E.
for an instance of `EnergyOperator` with its metric being a proper
Fisher information metric, or a sum or nested sum thereof.
Retruns
-------
np.dtype, or dict of np.dtype : The dtype(s) of the target space of the
transformation.
Operator : The transformation that maps from `domain` into the
Euclidean target space.
"""
return None
@staticmethod @staticmethod
def _check_domain_equality(dom_op, dom_field): def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field: if dom_op != dom_field:
...@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator): ...@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator):
x = op(x) x = op(x)
return x return x
def get_transformation(self):
tr = self._ops[0].get_transformation()
if tr is None:
return tr
return tr[0], _OpChain.make((tr[1],)+self._ops[1:])
def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain): if not isinstance(self._domain, MultiDomain):
...@@ -488,6 +511,25 @@ class _OpSum(Operator): ...@@ -488,6 +511,25 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric._myadd(lin2._metric, False)) res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res return res
def get_transformation(self):
tr1 = self._op1.get_transformation()
tr2 = self._op2.get_transformation()
if tr1 is None or tr2 is None: