Commit 4725ca7c authored by Philipp Arras's avatar Philipp Arras

Merge branch 'kl_fixes' into 'NIFTy_5'

Kl fixes

See merge request ift/nifty-dev!220
parents 3f52237c 998d3b9e
......@@ -15,62 +15,79 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .energy import Energy
from ..linearization import Linearization
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from .energy import Energy
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
A Metric Gaussian is used to approximate some other distribution.
It is a Gaussian distribution that uses the Fisher Information Metric
of the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, the a stochastic estimate of the
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
drawing samples from the Metric Gaussian at the current mean.
During minimization these samples are kept constant, updating only the
mean. Due to the typically nonlinear structure of the true distribution
these samples have to be updated by re-initializing this class at some
point. Here standard parametrization of the true distribution is assumed.
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by intantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
Parameters
----------
mean : Field
The current mean of the Gaussian.
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
The StandardHamiltonian of the approximated probability distribution.
Hamiltonian of the approximated probability distribution.
n_samples : integer
The number of samples used to stochastically estimate the KL.
Number of samples used to stochastically estimate the KL.
constants : list
A list of parameter keys that are kept constant during optimization.
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
A list of parameter keys for which no samples are drawn, but that are
optimized for, corresponding to point estimates of these.
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negative of the drawn samples are also used,
as they are equaly legitimate samples. If true, the number of used
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. (default : False)
Notes
-----
For further details see: Metric Gaussian Variational Inference
(FIXME in preparation)
extreme sample variation is counterbalanced. Default is False.
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
Note
----
The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
See also
--------
Metric Gaussian Variational Inference (FIXME in preparation)
"""
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=None, mirror_samples=False,
point_estimates=[], mirror_samples=False,
_samples=None):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = list(constants)
self._point_estimates = list(point_estimates)
if not isinstance(mirror_samples, bool):
raise TypeError
self._hamiltonian = hamiltonian
self._constants = constants
if point_estimates is None:
point_estimates = constants
self._constants_samples = point_estimates
if _samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
......@@ -96,7 +113,7 @@ class MetricGaussianKL(Energy):
def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._constants_samples,
self._constants, self._point_estimates,
_samples=self._samples)
@property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment