Commit c7bbcba3 authored by Philipp Arras's avatar Philipp Arras

Merge branch 'geomtric_kl_merge' into 'NIFTy_7'

Add Geometric KL

See merge request !626
parents d12411a4 c61b7f37
Pipeline #102812 passed with stages
in 27 minutes and 2 seconds
@@ -143,7 +143,7 @@ run_curve_fitting:
paths:
- '*.png'
run_visual_mgvi:
run_visual_vi:
stage: demo_runs
script:
- python3 demos/mgvi_visualized.py
- python3 demos/variational_inference_visualized.py
@@ -64,13 +64,32 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()` and for linear operators
`ift.extra.check_linear_operator()`.
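A minimal usage sketch of the renamed checks (the toy operator and the random
test position are illustrative assumptions; `check_operator` is assumed here to
take the position to linearize around as its second argument):

    import nifty7 as ift

    dom = ift.RGSpace(16)
    op = ift.ScalingOperator(dom, 2.0)
    # Consistency tests for linear operators (adjointness, linearity)
    ift.extra.check_linear_operator(op)
    # Value/derivative consistency tests, also applicable to nonlinear operators
    ift.extra.check_operator(op, ift.from_random(op.domain, 'normal'))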
MetricGaussianKL interface
--------------------------
Users do not instantiate `MetricGaussianKL` by its constructor anymore. Rather
`MetricGaussianKL.make()` shall be used. Additionally, `mirror_samples` is not
set by default anymore.
`mirror_samples` is not set by default anymore.
GeoMetricKL
-----------
A new posterior approximation scheme, called geometric Variational Inference
(geoVI), was introduced. `GeoMetricKL` extends `MetricGaussianKL` in the sense
that it uses (non-linear) geoVI samples instead of (linear) MGVI samples.
`GeoMetricKL` can be configured such that it reduces to `MetricGaussianKL`.
`GeoMetricKL` is now used in `demos/getting_started_3.py` and a visual
comparison to MGVI can be found in `demos/variational_inference_visualized.py`.
For further details see <https://arxiv.org/abs/2105.10470>.
LikelihoodOperator
------------------
A new subclass of `EnergyOperator` was introduced and all `EnergyOperator`s
that are likelihoods are now `LikelihoodOperator`s. A `LikelihoodOperator`
has to implement the function `get_transformation`, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix. This is needed for the `GeoMetricKL` algorithm.
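As a plain-NumPy illustration of this idea (not the NIFTy API; the covariance
values are made up): for a Gaussian likelihood with noise covariance `N`, the
Fisher metric in data space is `N^{-1}`, which is exactly the pullback of the
identity metric under the transformation `y = N^{-1/2} x`:

    import numpy as np

    N = np.diag([0.5, 2.0])                 # toy noise covariance
    jac = np.diag(1.0/np.sqrt(np.diag(N)))  # Jacobian of y = N^(-1/2) x
    fisher = jac.T @ jac                    # pullback of the identity metric
    print(np.allclose(fisher, np.linalg.inv(N)))  # True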
Changes since NIFTy 5
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
@@ -31,6 +31,7 @@ import numpy as np
import nifty7 as ift
ift.random.push_sseq_from_seed(27)
def random_los(n_los):
starts = list(ift.random.current_rng().random((n_los, 2)).T)
@@ -63,16 +64,16 @@ def main():
'offset_std': (1e-3, 1e-6),
# Amplitude of field fluctuations
'fluctuations': (2., 1.), # 1.0, 1e-2
'fluctuations': (1., 0.8), # 1.0, 1e-2
# Exponent of power law power spectrum component
'loglogavgslope': (-4., 1), # -6.0, 1
'loglogavgslope': (-3., 1), # -6.0, 1
# Amplitude of integrated Wiener process power spectrum component
'flexibility': (5, 2.), # 2.0, 1.0
'flexibility': (2, 1.), # 1.0, 0.5
# How ragged the integrated Wiener process component is
'asperity': (0.5, 0.5) # 0.1, 0.5
'asperity': (0.5, 0.4) # 0.1, 0.5
}
correlated_field = ift.SimpleCorrelatedField(position_space, **args)
@@ -89,7 +90,6 @@ def main():
# Specify noise
data_space = R.target
noise = .001
sig = ift.ScalingOperator(data_space, np.sqrt(noise))
N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data
@@ -97,11 +97,14 @@ def main():
data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(
deltaE=0.05, iteration_limit=100)
ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.5, iteration_limit=35)
ic_sampling = ift.AbsDeltaEnergyController(name="Sampling (linear)",
deltaE=0.05, iteration_limit=100)
ic_newton = ift.AbsDeltaEnergyController(name='Newton', deltaE=0.5,
convergence_level=2, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
ic_sampling_nl = ift.AbsDeltaEnergyController(name='Sampling (nonlin)',
deltaE=0.5, iteration_limit=15, convergence_level=2)
minimizer_sampling = ift.NewtonCG(ic_sampling_nl)
# Set up likelihood and information Hamiltonian
likelihood = (ift.GaussianEnergy(mean=data, inverse_covariance=N.inverse) @
@@ -112,18 +115,21 @@ def main():
mean = initial_mean
plot = ift.Plot()
plot.add(signal(mock_position), title='Ground Truth')
plot.add(signal(mock_position), title='Ground Truth', zmin = 0, zmax = 1)
plot.add(R.adjoint_times(data), title='Data')
plot.add([pspec.force(mock_position)], title='Power Spectrum')
plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename.format("setup"))
# number of samples used to estimate the KL
N_samples = 20
N_samples = 10
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples to approximate the KL six times
for i in range(6):
if i==5:
# Double the number of samples in the last step for better statistics
N_samples = 2*N_samples
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.GeoMetricKL(mean, H, N_samples, minimizer_sampling, True)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
@@ -131,7 +137,7 @@ def main():
# Plot current reconstruction
plot = ift.Plot()
plot.add(signal(KL.position), title="Latent mean")
plot.add(signal(KL.position), title="Latent mean", zmin = 0, zmax = 1)
plot.add([pspec.force(KL.position + ss) for ss in KL.samples],
title="Samples power spectrum")
plot.output(ny=1, ysize=6, xsize=16,
@@ -144,16 +150,16 @@ def main():
# Plotting
filename_res = filename.format("results")
plot = ift.Plot()
plot.add(sc.mean, title="Posterior Mean")
plot.add(sc.mean, title="Posterior Mean", zmin = 0, zmax = 1)
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [pspec.force(s + KL.position) for s in KL.samples]
sc = ift.StatCalculator()
for pp in powers:
sc.add(pp)
sc.add(pp.log())
plot.add(
powers + [pspec.force(mock_position),
pspec.force(KL.position), sc.mean],
pspec.force(KL.position), sc.mean.exp()],
title="Sampled Posterior Power Spectrum",
linewidth=[1.]*len(powers) + [3., 3., 3.],
label=[None]*len(powers) + ['Ground truth', 'Posterior latent mean', 'Posterior mean'])
@@ -134,7 +134,7 @@ def main():
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
@@ -94,7 +94,7 @@ if __name__ == "__main__":
for i in range(5):
# Draw new samples and minimize KL
kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
kl = ift.MetricGaussianKL(mean, ham, n_samples, True)
kl, convergence = minimizer(kl)
mean = kl.position
@@ -11,21 +11,21 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# Metric Gaussian Variational Inference (MGVI)
# Variational Inference (VI)
#
# This script demonstrates how MGVI works for an inference problem with only
# two real quantities of interest. This enables us to plot the posterior
# probability density as two-dimensional plot. The posterior samples generated
# by MGVI are contrasted with the maximum-a-posterior (MAP) solution together
# with samples drawn with the Laplace method. This method uses the local
# curvature at the MAP solution as inverse covariance of a Gaussian probability
# density.
# This script demonstrates how MGVI and GeoVI work for an inference problem
# with only two real quantities of interest. This enables us to plot the
# posterior probability density as a two-dimensional plot. The approximate
# posterior samples are contrasted with the maximum a posteriori (MAP) solution
# together with samples drawn with the Laplace method. This method uses the
# local curvature at the MAP solution as inverse covariance of a Gaussian
# probability density.
###############################################################################
import numpy as np
@@ -87,36 +87,52 @@ def main():
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = ift.from_random(ham.domain, 'normal')
plt.figure(figsize=[12, 8])
pos = pos1 = ift.from_random(ham.domain, 'normal')
fig, axs = plt.subplots(2, 1, figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL.make(pos, ham, 40, False)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
vmax=np.max(z), cmap='gist_earth_r',
extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar()
cbar.ax.set_ylabel('pdf')
xs, ys = [], []
for samp in mgkl.samples:
samp = (samp + pos).val
xs.append(samp['a'])
ys.append(samp['b'])
plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
plt.legend()
# Resample
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
for axx in axs:
axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
xs, ys = [], []
for samp in kl.samples:
samp = (samp + pp).val
xs.append(samp['a'])
ys.append(samp['b'])
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean')
axs[jj].set_title(nn)
for axx in axs:
axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posteriori solution')
axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x')
plt.tight_layout()
plt.draw()
plt.pause(1.0)
mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
pos = mgkl.position
pos1 = geokl.position
ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open
# plt.show()
@@ -369,17 +369,16 @@ tackling new IFT problems. An example of concrete energy classes delivered with
NIFTy7 is :class:`~minimization.quadratic_energy.QuadraticEnergy` (with
position-independent metric, mainly used with conjugate gradient minimization).
For MGVI, NIFTy provides the :class:`~energy.Energy` subclass
:class:`~minimization.metric_gaussian_kl.MetricGaussianKL`,
which computes the sampled estimated of the KL divergence, its gradient and the
Fisher metric. The constructor of
:class:`~minimization.metric_gaussian_kl.MetricGaussianKL` requires an instance
For MGVI and GeoVI, NIFTy provides :func:`~minimization.kl_energies.MetricGaussianKL`
and :func:`~minimization.kl_energies.GeoMetricKL` that instantiate objects
containing the sampled estimate of the KL divergence, its gradient and the
Fisher metric. These constructors require an instance
of :class:`~operators.energy_operators.StandardHamiltonian`, an operator to
compute the negative log-likelihood of the problem in standardized coordinates
at a given position in parameter space.
Finally, the :class:`~operators.energy_operators.StandardHamiltonian`
can be constructed from the likelihood, represented by an
:class:`~operators.energy_operators.EnergyOperator` instance.
can be constructed from the likelihood, represented by a
:class:`~operators.energy_operators.LikelihoodOperator` instance.
Several commonly used forms of the likelihoods are already provided in
NIFTy, such as :class:`~operators.energy_operators.GaussianEnergy`,
:class:`~operators.energy_operators.PoissonianEnergy`,
@@ -19,17 +19,16 @@ Also, there were certainly previous works in a similar spirit.
Anyhow, in many cases IFT provides novel rigorous ways to extract information from data.
NIFTy comes with reimplemented MAP and VI estimators.
.. tip:: *In-a-nutshell introductions to information field theory* can be found in [2]_, [3]_, [4]_, and [5]_, with the latter probably being the most didactical.
.. tip:: A didactical introduction to *information field theory* can be found in [2]_. Other resources include the `Wikipedia entry <https://en.wikipedia.org/wiki/Information_field_theory>`_, [3]_ and [4]_.
.. [1] T.A. Enßlin et al. (2009), "Information field theory for cosmological perturbation reconstruction and nonlinear signal analysis", PhysRevD.80.105005, 09/2009; `[arXiv:0806.3474] <https://www.arxiv.org/abs/0806.3474>`_
.. [2] T.A. Enßlin (2013), "Information field theory", proceedings of MaxEnt 2012 -- the 32nd International Workshop on Bayesian Inference and Maximum Entropy Methods in Science and Engineering; AIP Conference Proceedings, Volume 1553, Issue 1, p.184; `[arXiv:1301.2556] <https://arxiv.org/abs/1301.2556>`_
.. [2] T.A. Enßlin (2019), "Information theory for fields", Annalen der Physik; `[DOI] <https://doi.org/10.1002/andp.201800127>`_, `[arXiv:1804.03350] <https://arxiv.org/abs/1804.03350>`_
.. [3] T.A. Enßlin (2014), "Astrophysical data analysis with information field theory", AIP Conference Proceedings, Volume 1636, Issue 1, p.49; `[arXiv:1405.7701] <https://arxiv.org/abs/1405.7701>`_
.. [3] T.A. Enßlin (2013), "Information field theory", proceedings of MaxEnt 2012 -- the 32nd International Workshop on Bayesian Inference and Maximum Entropy Methods in Science and Engineering; AIP Conference Proceedings, Volume 1553, Issue 1, p.184; `[arXiv:1301.2556] <https://arxiv.org/abs/1301.2556>`_
.. [4] Wikipedia contributors (2018), `"Information field theory" <https://en.wikipedia.org/w/index.php?title=Information_field_theory&oldid=876731720>`_, Wikipedia, The Free Encyclopedia.
.. [4] T.A. Enßlin (2014), "Astrophysical data analysis with information field theory", AIP Conference Proceedings, Volume 1636, Issue 1, p.49; `[arXiv:1405.7701] <https://arxiv.org/abs/1405.7701>`_
.. [5] T.A. Enßlin (2019), "Information theory for fields", accepted by Annalen der Physik; `[DOI] <https://doi.org/10.1002/andp.201800127>`_, `[arXiv:1804.03350] <https://arxiv.org/abs/1804.03350>`_
@@ -133,8 +132,8 @@ NIFTy takes advantage of this formulation in several ways:
3) The response can be non-linear, e.g. :math:`{R'(s)=R \exp(A\,\xi)}`, see `demos/getting_started_2.py`.
4) The amplitude operator may depend on further parameters, e.g. :math:`A=A(\tau)= F\, \widehat{e^\tau}` represents an amplitude operator with a positive definite, unknown spectrum defined in the Fourier domain.
The amplitude field :math:`{\tau}` would get its own amplitude operator, with a cepstrum (spectrum of a log spectrum) defined in quefrency space (harmonic space of a logarithmically binned harmonic space) to regularize its degrees of freedom by imposing some (user-defined degree of) spectral smoothness.
4) The amplitude operator may depend on further parameters, e.g. :math:`A=A(\tau)=F\, \widehat{e^\tau}` represents an amplitude operator with a positive definite, unknown spectrum.
The log-amplitude field :math:`{\tau}` is modelled with the help of an integrated Wiener process in order to impose some (user-defined degree of) spectral smoothness.
5) NIFTy calculates the gradient of the information Hamiltonian and the Fisher information metric with respect to all unknown parameters, here :math:`{\xi}` and :math:`{\tau}`, by automatic differentiation.
The gradients are used for MAP estimates, and the Fisher matrix is required in addition to the gradient by Metric Gaussian Variational Inference (MGVI), which is available in NIFTy as well.
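A sketch of this mechanism, using the `Linearization` pattern that also appears in the module listing at the end of this diff (`H` and `mean` stand for the Hamiltonian and the latent position from the demos above; the `want_metric` flag is assumed to be what requests the Fisher metric)::

    lin = ift.Linearization.make_var(mean, want_metric=True)
    res = H(lin)
    res.val       # value of the information Hamiltonian at `mean`
    res.gradient  # gradient w.r.t. all latent parameters
    res.metric    # Fisher information metric as an implicit operator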
@@ -176,7 +175,7 @@ It only requires minimizing the information Hamiltonian, e.g. by a gradient desc
NIFTy7 automatically calculates the necessary gradient from a generative model of the signal and the data and uses this to minimize the Hamiltonian.
However, MAP often provides unsatisfactory results in cases of deep hirachical Bayesian networks.
However, MAP often provides unsatisfactory results in cases of deep hierarchical Bayesian networks.
The reason for this is that MAP ignores the volume factors in parameter space, which are not to be neglected in deciding whether a solution is reasonable or not.
In the high dimensional setting of field inference these volume factors can differ by large ratios.
A MAP estimate, which is only representative for a tiny fraction of the parameter space, might be a poorer choice (with respect to an error norm) compared to a slightly worse location with slightly lower posterior probability, which, however, is associated with a much larger volume (of nearby locations with similar probability).
@@ -184,8 +183,8 @@ A MAP estimate, which is only representative for a tiny fraction of the paramete
This causes MAP signal estimates to be more prone to overfitting the noise as well as to perception thresholds than methods that take volume effects into account.
Variational Inference
---------------------
Metric Gaussian Variational Inference
-------------------------------------
One method that takes volume effects into account is Variational Inference (VI).
In VI, the posterior :math:`\mathcal{P}(\xi|d)` is approximated by a simpler, parametrized distribution, often a Gaussian :math:`\mathcal{Q}(\xi)=\mathcal{G}(\xi-m,D)`.
@@ -214,7 +213,7 @@ The only term within the KL-divergence that explicitly depends on it is the Hami
\mathrm{KL}(m|d) \;\widehat{=}\;
\left\langle \mathcal{H}(\xi,d) \right\rangle_{\mathcal{Q}(\xi)},
where :math:`\widehat{=}` expresses equality up to irrelvant (here not :math:`m`-dependent) terms.
where :math:`\widehat{=}` expresses equality up to irrelevant (here not :math:`m`-dependent) terms.
Thus, only the gradient of the KL is needed with respect to this, which can be expressed as
@@ -224,11 +223,39 @@ Thus, only the gradient of the KL is needed with respect to this, which can be e
We stochastically estimate the KL-divergence and gradients with a set of samples drawn from the approximate posterior distribution.
The particular structure of the covariance allows us to draw independent samples solving a certain system of equations.
This KL-divergence for MGVI is implemented in the class :class:`~nifty.7minimization.metric_gaussian_kl.MetricGaussianKL` within NIFTy7.
This KL-divergence for MGVI is implemented by
:func:`~nifty7.minimization.kl_energies.MetricGaussianKL` within NIFTy7.
Note that MGVI typically provides only a lower bound on the variance.
The demo `getting_started_3.py`, for example, not only infers a field this way, but also the power spectrum of the process that has generated the field.
The cross-correlation of field and power spectrum is taken care of in this process.
Posterior samples can be obtained to study this cross-correlation.
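Schematically, the resulting minimization loop looks like the one in the demos of this merge request (names as in `getting_started_3.py`; the final positional argument enables sample mirroring)::

    minimizer = ift.NewtonCG(
        ift.AbsDeltaEnergyController(name='Newton', deltaE=0.5,
                                     iteration_limit=35))
    for _ in range(5):
        # Draw new samples around the current mean, then re-minimize the KL
        KL = ift.MetricGaussianKL(mean, H, N_samples, True)
        KL, convergence = minimizer(KL)
        mean = KL.position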
Geometric Variational Inference
-------------------------------
It should be noted that MGVI, like any VI method, can typically only provide a lower bound on the variance.
For non-linear posterior distributions :math:`\mathcal{P}(\xi|d)`, an approximation with a Gaussian :math:`\mathcal{Q}(\xi)` in the coordinates :math:`\xi` is sub-optimal, as higher-order interactions are ignored.
A better approximation can be achieved by constructing a coordinate system :math:`y = g\left(\xi\right)` in which the posterior is close to a Gaussian, and performing VI with a Gaussian :math:`\mathcal{Q}(y)` in these coordinates.
This approach is called Geometric Variational Inference (geoVI).
It is discussed in detail in [6]_.
One useful coordinate system is obtained when the metric :math:`M` of the posterior can be expressed as the pullback of the Euclidean metric by :math:`g`:
.. math::
    M = \left(\frac{\partial g}{\partial \xi}\right)^T \frac{\partial g}{\partial \xi} \ .
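A one-dimensional toy check of this condition (the metric below is an assumed example, not taken from NIFTy): for :math:`M(\xi) = e^{2\xi}`, the transformation :math:`g(\xi) = e^{\xi}` fulfills :math:`M = (\partial g/\partial \xi)^2`::

    import numpy as np

    xi = np.linspace(-1.0, 1.0, 5)
    dg = np.exp(xi)               # dg/dxi for g(xi) = exp(xi)
    M = np.exp(2*xi)              # assumed toy metric
    print(np.allclose(M, dg*dg))  # True: pullback of the flat metric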
In general, such a transformation exists only locally, i.e. in a neighbourhood of some expansion point :math:`\bar{\xi}`, denoted as :math:`g_{\bar{\xi}}\left(\xi\right)`.
Using :math:`g_{\bar{\xi}}`, the geoVI scheme approximates the posterior with a zero-mean, unit-covariance Gaussian :math:`\mathcal{Q}(y) = \mathcal{G}(y, 1)` in the new coordinates.
It can be expressed in :math:`\xi` coordinates via the pushforward by the inverse transformation :math:`\xi = g_{\bar{\xi}}^{-1}(y)`:
.. math::
    \mathcal{Q}_{\bar{\xi}}(\xi) = \left(g_{\bar{\xi}}^{-1} * \mathcal{Q}\right)(\xi) = \int \delta\left(\xi - g_{\bar{\xi}}^{-1}(y)\right) \ \mathcal{G}(y, 1) \ \mathcal{D}y \ ,
where :math:`\delta` denotes the Dirac delta distribution.
geoVI obtains the optimal expansion point :math:`\bar{\xi}` such that :math:`\mathcal{Q}_{\bar{\xi}}` matches the posterior as well as possible.
Analogous to the MGVI algorithm, :math:`\bar{\xi}` is obtained by minimization of the KL-divergence between :math:`\mathcal{P}` and :math:`\mathcal{Q}_{\bar{\xi}}` w.r.t. :math:`\bar{\xi}`.
Furthermore, the KL is represented as a stochastic estimate using a set of samples drawn from :math:`\mathcal{Q}_{\bar{\xi}}`; this is implemented in NIFTy7 via :func:`~nifty7.minimization.kl_energies.GeoMetricKL`.
A visual comparison of the MGVI and GeoVI algorithm can be found in `variational_inference_visualized.py <https://gitlab.mpcdf.mpg.de/ift/nifty/-/blob/NIFTy_7/demos/variational_inference_visualized.py>`_.
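In code, drawing the geoVI KL mirrors the MGVI case, with an additional minimizer that solves for the nonlinear samples (pattern and names as in `getting_started_3.py` above)::

    ic_sampling_nl = ift.AbsDeltaEnergyController(
        name='Sampling (nonlin)', deltaE=0.5, iteration_limit=15,
        convergence_level=2)
    minimizer_sampling = ift.NewtonCG(ic_sampling_nl)
    KL = ift.GeoMetricKL(mean, H, N_samples, minimizer_sampling, True)
    KL, convergence = minimizer(KL)
    mean = KL.position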
.. [6] P. Frank, R. Leike, and T.A. Enßlin (2021), "Geometric Variational Inference"; `[arXiv:2105.10470] <https://arxiv.org/abs/2105.10470>`_
@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import *
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .. import random, utilities
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeOp
from ..utilities import myassert
from .energy import Energy
class _KLMetric(EndomorphicOperator):
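    """Expose the metric of a KL energy as an `EndomorphicOperator`, so that
    `KL.apply_metric` can be used wherever an operator is expected."""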
    def __init__(self, KL):
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._KL.apply_metric(x)
def _get_lo_hi(comm, n_samples):
    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
    return utilities.shareRange(n_samples, ntask, rank)
def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
for keys which are not in sample.domain."""
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi_domain import MultiDomain
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple) and isinstance(sample, Field):
if sample.domain is not domain:
raise TypeError
return sample
elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
if sample.domain is domain:
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
return out
raise TypeError
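# Hypothetical illustration of the helper above (key names are made up):
# a sample with keys {'xi', 'old'} adapted to a MultiDomain with keys
# {'xi', 'tau'} keeps 'xi', drops 'old', and gets zeros for 'tau' (inserted
# by `MultiField.from_dict`).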
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by intantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Notes
-----
DomainTuples should never be created using the constructor, but rather
via the factory function :attr:`make`!
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
                 local_samples, nanisinf, _callingfrommake=False):
        if not _callingfrommake:
            raise NotImplementedError
        super(MetricGaussianKL, self).__init__(mean)
        myassert(mean.domain is hamiltonian.domain)
        self._hamiltonian = hamiltonian
        self._n_samples = int(n_samples)
        self._mirror_samples = bool(mirror_samples)
        self._comm = comm
        self._local_samples = local_samples
        self._nanisinf = bool(nanisinf)

        lin = Linearization.make_var(mean)
        v, g = [], []
        for s in self._local_samples:
            s = _modify_sample_domain(s, mean.domain)
            tmp = hamiltonian(lin+s)
            tv = tmp.val.val
            tg = tmp.gradient
            if mirror_samples:
                tmp = hamiltonian(lin-s)
                tv = tv + tmp.val.val
                tg = tg + tmp.gradient
            v.append(tv)
            g.append(tg)
        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
        if np.isnan(self._val) and self._nanisinf:
            self._val = np.inf
        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
    @staticmethod
    def make(mean, hamiltonian, n_samples, mirror_samples, constants=[],
             point_estimates=[], napprox=0, comm=None, nanisinf=False):
        """Return instance of :class:`MetricGaussianKL`.

        Parameters
        ----------
        mean : Field
            Mean of the Gaussian probability distribution.
        hamiltonian : StandardHamiltonian
            Hamiltonian of the approximated probability distribution.
        n_samples : integer
            Number of samples used to stochastically estimate the KL.
        mirror_samples : boolean
            Whether the negatives of the drawn samples are also used, as they
            are equally legitimate samples. If true, the number of used
            samples doubles. Mirroring samples stabilizes the KL estimate as
            extreme sample variation is counterbalanced. Since it improves
            stability in many cases, it is recommended to set
            `mirror_samples` to `True`.
        constants : list
            List of parameter keys that are kept constant during optimization.
            Default is no constants.
        point_estimates : list
            List of parameter keys for which no samples are drawn, but that
            are (possibly) optimized for, corresponding to point estimates of
            these parameters. Default is to draw samples for the complete
            domain.
        napprox : int
            Number of samples used to compute the preconditioner for
            sampling. No preconditioning is done by default.
        comm : MPI communicator or None
            If not None, samples will be distributed as evenly as possible
            across this communicator. If `mirror_samples` is set, then a
            sample and its mirror image will always reside on the same task.
        nanisinf : bool
            If true, NaN energies, which can happen due to overflows in the
            forward model, are interpreted as inf. Thereby, the code does not
            crash on these occasions, but rather the minimizer is told that
            the position it has tried is not sensible.
        Note
        ----
        The two lists `constants` and `point_estimates` are independent from
        each other. It is possible to sample along domains which are kept
        constant during minimization and vice versa.
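        A hypothetical combined use (the key names 'xi' and 'tau' are made up
        for illustration)::

            kl = MetricGaussianKL.make(mean, hamiltonian, n_samples=10,
                                       mirror_samples=True,
                                       constants=['xi'],
                                       point_estimates=['tau'])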