Commit 21f0128d authored by Philipp Frank, committed by Philipp Arras

Implement Geometric KL

parent 0acc9ff7
......@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()` and for linear operators
`ift.extra.check_linear_operator()`.
MetricGaussianKL interface
--------------------------
Users do not instantiate `MetricGaussianKL` by its constructor anymore. Rather
`MetricGaussianKL.make()` shall be used. Additionally, `mirror_samples` is not
set by default anymore.
`mirror_samples` is not set by default anymore.
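A minimal sketch of the resulting call (variable names as in the demo scripts
below; `mirror_samples` is passed explicitly since it has no default):

    kl = ift.MetricGaussianKL(mean, ham, n_samples, mirror_samples=True)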
GeoMetricKL
-----------
A new posterior approximation scheme, called geometric Variational Inference
(geoVI), was introduced. `GeoMetricKL` is analogous to `MetricGaussianKL` with
the exception that geoVI samples are used instead of MGVI samples. For further
details see <https://arxiv.org/abs/2105.10470>.
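A minimal usage sketch (illustrative only, reusing the names from the demos
below; note the additional minimizer that geoVI needs for drawing its
non-linear samples):

    minimizer_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
    kl = ift.GeoMetricKL(mean, ham, n_samples, minimizer_samp, mirror_samples=True)
    kl, convergence = minimizer(kl)
    mean = kl.position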
LikelihoodOperator
------------------
A new subclass of `EnergyOperator` was introduced, and all `EnergyOperator`s
that are likelihoods are now `LikelihoodOperator`s. A `LikelihoodOperator`
has to implement the method `get_transformation`, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix.
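`get_transformation` returns a pair `(dtype, f)`; the Fisher metric is then
the Jacobian of `f` sandwiched with itself. The snippet below is a condensed,
illustrative sketch of how the geoVI sampler introduced in this commit
consumes it (`likelihood` and `position` are placeholder names):

    tr = likelihood.get_transformation()          # None if not available
    dtype, f = tr
    fl = f(ift.Linearization.make_var(position))  # linearize at the current position
    fisher = ift.SandwichOperator.make(fl.jac)    # = fl.jac.adjoint @ fl.jac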
Changes since NIFTy 5
......
......@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
......
......@@ -134,7 +134,7 @@ def main():
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
......
......@@ -94,7 +94,7 @@ if __name__ == "__main__":
for i in range(5):
# Draw new samples and minimize KL
kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
kl = ift.MetricGaussianKL(mean, ham, n_samples, True)
kl, convergence = minimizer(kl)
mean = kl.position
......
......@@ -36,6 +36,8 @@ import nifty7 as ift
def main():
use_geo = False
name = 'GEO' if use_geo else 'MGVI'
dom = ift.UnstructuredDomain(1)
scale = 10
......@@ -91,12 +93,16 @@ def main():
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL.make(pos, ham, 40, False)
if use_geo:
mini_samp = ift.NewtonCG(
ift.GradientNormController(iteration_limit=5))
mgkl = ift.GeoMetricKL(pos, ham, 100, mini_samp, False)
else:
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
vmax=np.max(z), cmap='gist_earth_r',
extent=x_limits_scaled + y_limits)
plt.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar()
cbar.ax.set_ylabel('pdf')
......@@ -105,12 +111,14 @@ def main():
samp = (samp + pos).val
xs.append(samp['a'])
ys.append(samp['b'])
plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
plt.scatter(np.array(xs)*scale, np.array(ys), label=name+' samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label=name+' latent mean')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posteriori solution')
plt.xlim(x_limits_scaled)
plt.ylim(y_limits)
plt.legend()
plt.draw()
plt.pause(1.0)
......
......@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import *
......
......@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out += f"{ndof[kk]:>11}"
out += "\n"
return out[:-1]
class _KeyModifier(LinearOperator):
"""Prepends the string `pre` to every key of a MultiField defined on
`domain`; in adjoint mode the prefix is stripped again."""
def __init__(self, domain, pre):
if not isinstance(domain, MultiDomain):
raise ValueError
from .sugar import makeDomain
self._domain = makeDomain(domain)
self._pre = str(pre)
target = {self._pre+k: domain[k] for k in domain.keys()}
self._target = makeDomain(MultiDomain.make(target))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = {self._pre+k:x[k] for k in self._domain.keys()}
else:
res = {k:x[self._pre+k] for k in self._domain.keys()}
return MultiField.from_dict(res, domain=self._tgt(mode))
......@@ -11,33 +11,43 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from functools import reduce
from .. import random, utilities
from .. import random
from .. import utilities
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.inversion_enabler import InversionEnabler
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..operators.energy_operators import StandardHamiltonian, GaussianEnergy
from ..operators.sandwich_operator import SandwichOperator
from ..operators.sampling_enabler import SamplingDtypeSetter
from ..operators.scaling_operator import ScalingOperator
from .energy_adapter import EnergyAdapter
from .quadratic_energy import QuadraticEnergy
from ..probing import approximation2endo
from ..sugar import makeOp
from ..operators.adder import Adder
from ..utilities import myassert
from .energy import Energy
from .descent_minimizers import DescentMinimizer, ConjugateGradient
class _KLMetric(EndomorphicOperator):
def __init__(self, KL):
self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = KL.position.domain
def apply(self, x, mode):
self._check_input(x, mode)
return self._KL.apply_metric(x)
def _is_prior_dtype_float(H):
"""Returns True iff all sampling dtypes of the prior metric of `H` are real."""
real = True
dts = H._prior._met._dtype
if isinstance(dts, dict):
for k in dts.keys():
if not np.issubdtype(dts[k], np.float):
real = False
else:
real = np.issubdtype(dts, np.float)
return real
def _get_lo_hi(comm, n_samples):
......@@ -48,9 +58,9 @@ def _get_lo_hi(comm, n_samples):
def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
for keys which are not in sample.domain."""
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi_domain import MultiDomain
from ..field import Field
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple) and isinstance(sample, Field):
......@@ -66,37 +76,32 @@ def _modify_sample_domain(sample, domain):
raise TypeError
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
def _reduce_by_keys(mean, hamiltonian, keys):
if isinstance(mean, MultiField):
_, new_ham = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(keys))
new_mean = mean.extract_by_keys(set(mean.keys()) - set(keys))
return new_mean, new_ham
return mean, hamiltonian
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by instantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Notes
-----
class _KLMetric(EndomorphicOperator):
def __init__(self, KL):
self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = KL.position.domain
DomainTuples should never be created using the constructor, but rather
via the factory function :attr:`make`!
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
def apply(self, x, mode):
self._check_input(x, mode)
return self._KL.apply_metric(x)
class _SampledKLEnergy(Energy):
"""Base class for Energies representing a sampled Kullback-Leibler divergence
for the variational approximation of a distribution with another distribution.
"""
def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
local_samples, nanisinf, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MetricGaussianKL, self).__init__(mean)
local_samples, nanisinf):
super(_SampledKLEnergy, self).__init__(mean)
myassert(mean.domain is hamiltonian.domain)
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
......@@ -104,7 +109,7 @@ class MetricGaussianKL(Energy):
self._comm = comm
self._local_samples = local_samples
self._nanisinf = bool(nanisinf)
lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
......@@ -123,94 +128,6 @@ class MetricGaussianKL(Energy):
self._val = np.inf
self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
@staticmethod
def make(mean, hamiltonian, n_samples, mirror_samples, constants=[],
point_estimates=[], napprox=0, comm=None, nanisinf=False):
"""Return instance of :class:`MetricGaussianKL`.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
mirror_samples : boolean
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles. Mirroring samples stabilizes the KL estimate as extreme
sample variation is counterbalanced. Since it improves stability in
many cases, it is recommended to set `mirror_samples` to `True`.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occasions but rather the minimizer is told that the position it
has tried is not sensible.
Note
----
The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
"""
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError(
'Point estimates for whole domain. Use EnergyAdapter instead.')
n_samples = int(n_samples)
mirror_samples = bool(mirror_samples)
if isinstance(mean, MultiField):
cstpos = mean.extract_by_keys(point_estimates)
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
ham_sampling = hamiltonian
lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
met = ham_sampling(lin).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
local_samples.append(met.draw_sample(from_inverse=True))
local_samples = tuple(local_samples)
if isinstance(mean, MultiField):
_, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
mean = mean.extract_by_keys(set(mean.keys()) - set(constants))
return MetricGaussianKL(
mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
nanisinf, _callingfrommake=True)
def at(self, position):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._mirror_samples,
self._comm, self._local_samples, self._nanisinf, True)
@property
def value(self):
return self._val
......@@ -219,6 +136,11 @@ class MetricGaussianKL(Energy):
def gradient(self):
return self._grad
def at(self, position):
return _SampledKLEnergy(
position, self._hamiltonian, self._n_samples, self._mirror_samples,
self._comm, self._local_samples, self._nanisinf)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
res = []
......@@ -258,3 +180,304 @@ class MetricGaussianKL(Energy):
yield s
if self._mirror_samples:
yield -s
class _GeoMetricSampler:
def __init__(self, position, H, minimizer, start_from_lin,
n_samples, mirror_samples, napprox=0, want_error=False):
if not isinstance(H, StandardHamiltonian):
raise NotImplementedError
if not _is_prior_dtype_float(H):
raise ValueError("_GeoMetricSampler only supports real valued latent DOFs.")
if isinstance(position, MultiField):
self._position = position.extract(H.domain)
else:
self._position = position
tr = H._lh.get_transformation()
if tr is None:
raise ValueError("_GeoMetricSampler only works for likelihoods")
dtype, f_lh = tr
scale = ScalingOperator(f_lh.target, 1.)
if isinstance(dtype, dict):
sampling = reduce((lambda a,b: a*b),
[dtype[k] is not None for k in dtype.keys()])
else:
sampling = dtype is not None
scale = SamplingDtypeSetter(scale, dtype) if sampling else scale
fl = f_lh(Linearization.make_var(self._position))
self._g = (Adder(-self._position) +
fl.jac.adjoint@Adder(-fl.val)@f_lh)
self._likelihood = SandwichOperator.make(fl.jac, scale)
self._prior = SamplingDtypeSetter(ScalingOperator(fl.domain,1.), np.float64)
self._met = self._likelihood + self._prior
if napprox >=1:
self._approximation = makeOp(approximation2endo(self._met, napprox)).inverse
else:
self._approximation = None
self._ic = H._ic_samp
self._minimizer = minimizer
self._start_from_lin = start_from_lin
self._want_error = want_error
sseq = random.spawn_sseq(n_samples)
if mirror_samples:
mysseq = []
for seq in sseq:
mysseq += [seq, seq]
else:
mysseq = sseq
self._sseq = mysseq
self._neg = (False, True)*n_samples if mirror_samples else (False, )*n_samples
self._n_samples = n_samples
self._mirror_samples = mirror_samples
@property
def n_eff_samples(self):
return 2*self._n_samples if self._mirror_samples else self._n_samples
@property
def position(self):
return self._position
def _draw_lin(self, neg):
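# Draws a standard-normal prior sample s and likelihood-metric noise nj
# (both sign-flipped if `neg`, realizing mirrored samples) and forms the
# auxiliary variable y = s + nj.  If start_from_lin, an approximate CG
# inversion of the full metric is applied, yielding an MGVI-like linear
# sample yi as the starting point for _draw_nonlin; otherwise yi is simply s.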
s = self._prior.draw_sample(from_inverse=True)
s = -s if neg else s
nj = self._likelihood.draw_sample()
nj = -nj if neg else nj
y = self._prior(s) + nj
if self._start_from_lin:
energy = QuadraticEnergy(s, self._met, y,
_grad=self._likelihood(s) - nj)
inverter = ConjugateGradient(self._ic)
energy, convergence = inverter(energy,
preconditioner=self._approximation)
yi = energy.position
else:
yi = s
return y, yi
def _draw_nonlin(self, y, yi):
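# Refines the linear sample: starting at position + yi, minimizes the
# non-linear energy 1/2 |y - g(x)|^2 (a unit-covariance GaussianEnergy
# around y composed with the coordinate transformation g built in __init__)
# and returns the optimum relative to the expansion point, optionally
# together with an error estimate of the residual.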
en = EnergyAdapter(self._position+yi, GaussianEnergy(mean=y)@self._g,
nanisinf=True, want_metric=True)
en, _ = self._minimizer(en)
sam = en.position - self._position
if self._want_error:
er = y - self._g(sam)
er = er.s_vdot(InversionEnabler(self._met, self._ic).inverse(er))
return sam, er
return sam
def draw_samples(self, comm):
local_samples = []
prev = None
for i in range(*_get_lo_hi(comm, self.n_eff_samples)):
with random.Context(self._sseq[i]):
neg = self._neg[i]
if (prev is None) or not self._mirror_samples:
y, yi = self._draw_lin(neg)
if not neg:
prev = (-y, -yi)
else:
(y, yi) = prev
prev = None
local_samples.append(self._draw_nonlin(y, yi))
return tuple(local_samples)
def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
point_estimates=[], napprox=0, comm=None, nanisinf=False):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by instantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
mirror_samples : boolean
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles. Mirroring samples stabilizes the KL estimate as extreme
sample variation is counterbalanced. Since it improves stability in
many cases, it is recommended to set `mirror_samples` to `True`.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occasions but rather the minimizer is told that the position it
has tried is not sensible.
Notes
-----
The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
`_SampledKLEnergy` objects should never be created using their constructor,
but rather via factory functions like this one!
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError(
'Point estimates for whole domain. Use EnergyAdapter instead.')
n_samples = int(n_samples)
mirror_samples = bool(mirror_samples)
_, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
met = ham_sampling(lin).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
local_samples.append(met.draw_sample(from_inverse=True))
local_samples = tuple(local_samples)
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
return _SampledKLEnergy(mean, hamiltonian, n_samples, mirror_samples, comm,
local_samples, nanisinf)
def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
start_from_lin = True, constants=[], point_estimates=[],
napprox=0, comm=None, nanisinf=True):
"""Provides the sampled Kullback-Leibler used in geometric Variational
Inference (geoVI).
In geoVI a probability distribution is approximated with a standard normal