Commit 00a2a0c5 authored by Martin Reinecke
Browse files

Merge branch 'KL_docstrings' into 'NIFTy_5'

KL docstrings

See merge request ift/nifty-dev!190
parents d1e5a734 01e88a70
...@@ -74,7 +74,7 @@ if __name__ == '__main__': ...@@ -74,7 +74,7 @@ if __name__ == '__main__':
ic_sampling = ift.GradientNormController(iteration_limit=100) ic_sampling = ift.GradientNormController(iteration_limit=100)
# Minimize the Hamiltonian # Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
H = ift.EnergyAdapter(position, H, want_metric=True) H = ift.EnergyAdapter(position, H, want_metric=True)
# minimizer = ift.L_BFGS(ic_newton) # minimizer = ift.L_BFGS(ic_newton)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
...@@ -99,7 +99,7 @@ if __name__ == '__main__': ...@@ -99,7 +99,7 @@ if __name__ == '__main__':
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# Compute MAP solution by minimizing the information Hamiltonian # Compute MAP solution by minimizing the information Hamiltonian
H = ift.Hamiltonian(likelihood) H = ift.StandardHamiltonian(likelihood)
initial_position = ift.from_random('normal', domain) initial_position = ift.from_random('normal', domain)
H = ift.EnergyAdapter(initial_position, H, want_metric=True) H = ift.EnergyAdapter(initial_position, H, want_metric=True)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
...@@ -100,10 +100,10 @@ if __name__ == '__main__': ...@@ -100,10 +100,10 @@ if __name__ == '__main__':
# Set up likelihood and information Hamiltonian # Set up likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response) likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response)
H = ift.Hamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_position = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
position = initial_position mean = initial_mean
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(mock_position), title='Ground Truth') plot.add(signal(mock_position), title='Ground Truth')
...@@ -117,9 +117,9 @@ if __name__ == '__main__': ...@@ -117,9 +117,9 @@ if __name__ == '__main__':
# Draw new samples to approximate the KL five times # Draw new samples to approximate the KL five times
for i in range(5): for i in range(5):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.KL_Energy(position, H, N_samples) KL = ift.MetricGaussianKL(mean, H, N_samples)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
position = KL.position mean = KL.position
# Plot current reconstruction # Plot current reconstruction
plot = ift.Plot() plot = ift.Plot()
...@@ -128,7 +128,7 @@ if __name__ == '__main__': ...@@ -128,7 +128,7 @@ if __name__ == '__main__':
plot.output(ny=1, ysize=6, xsize=16, name="loop-{:02}.png".format(i)) plot.output(ny=1, ysize=6, xsize=16, name="loop-{:02}.png".format(i))
# Draw posterior samples # Draw posterior samples
KL = ift.KL_Energy(position, H, N_samples) KL = ift.MetricGaussianKL(mean, H, N_samples)
sc = ift.StatCalculator() sc = ift.StatCalculator()
for sample in KL.samples: for sample in KL.samples:
sc.add(signal(sample + KL.position)) sc.add(signal(sample + KL.position))
......
...@@ -103,7 +103,7 @@ N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) ...@@ -103,7 +103,7 @@ N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.DeltaEnergyController(tol_rel_deltaE=1e-12, iteration_limit=200) IC = ift.DeltaEnergyController(tol_rel_deltaE=1e-12, iteration_limit=200)
likelihood = ift.GaussianEnergy(d, N)(R) likelihood = ift.GaussianEnergy(d, N)(R)
Ham = ift.Hamiltonian(likelihood, IC) Ham = ift.StandardHamiltonian(likelihood, IC)
H = ift.EnergyAdapter(params, Ham, want_metric=True) H = ift.EnergyAdapter(params, Ham, want_metric=True)
# Minimize # Minimize
......
...@@ -49,7 +49,7 @@ from .operators.simple_linear_operators import ( ...@@ -49,7 +49,7 @@ from .operators.simple_linear_operators import (
from .operators.value_inserter import ValueInserter from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, Hamiltonian, AveragedEnergy) BernoulliEnergy, StandardHamiltonian, AveragedEnergy)
from .probing import probe_with_posterior_samples, probe_diagonal, \ from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator
...@@ -68,7 +68,7 @@ from .minimization.scipy_minimizer import (ScipyMinimizer, L_BFGS_B, ScipyCG) ...@@ -68,7 +68,7 @@ from .minimization.scipy_minimizer import (ScipyMinimizer, L_BFGS_B, ScipyCG)
from .minimization.energy import Energy from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter from .minimization.energy_adapter import EnergyAdapter
from .minimization.kl_energy import KL_Energy from .minimization.metric_gaussian_kl import MetricGaussianKL
from .sugar import * from .sugar import *
from .plot import Plot from .plot import Plot
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from ..minimization.energy_adapter import EnergyAdapter from ..minimization.energy_adapter import EnergyAdapter
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood from ..operators.energy_operators import StandardHamiltonian, InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
...@@ -53,8 +53,8 @@ def make_adjust_variances(a, ...@@ -53,8 +53,8 @@ def make_adjust_variances(a,
Returns Returns
------- -------
Energy StandardHamiltonian
Hamiltonian that can be used for further minimization. A Hamiltonian that can be used for further minimization.
""" """
d = a*xi d = a*xi
...@@ -72,7 +72,7 @@ def make_adjust_variances(a, ...@@ -72,7 +72,7 @@ def make_adjust_variances(a,
if scaling is not None: if scaling is not None:
x = ScalingOperator(scaling, x.target)(x) x = ScalingOperator(scaling, x.target)(x)
return Hamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp) return StandardHamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp)
def do_adjust_variances(position, def do_adjust_variances(position,
......
...@@ -20,31 +20,70 @@ from ..linearization import Linearization ...@@ -20,31 +20,70 @@ from ..linearization import Linearization
from .. import utilities from .. import utilities
class KL_Energy(Energy): class MetricGaussianKL(Energy):
def __init__(self, position, h, nsamp, constants=[], """Provides the sampled Kullback-Leibler divergence between a distribution
constants_samples=None, gen_mirrored_samples=False, and a Metric Gaussian.
A Metric Gaussian is used to approximate some other distribution.
It is a Gaussian distribution that uses the Fisher Information Metric
of the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
drawing samples from the Metric Gaussian at the current mean.
During minimization these samples are kept constant, updating only the
mean. Due to the typically nonlinear structure of the true distribution
these samples have to be updated by re-initializing this class at some
point. Here, the standard parametrization of the true distribution is assumed.
Parameters
----------
mean : Field
The current mean of the Gaussian.
hamiltonian : StandardHamiltonian
The StandardHamiltonian of the approximated probability distribution.
n_samples : integer
The number of samples used to stochastically estimate the KL.
constants : list
A list of parameter keys that are kept constant during optimization.
point_estimates : list
A list of parameter keys for which no samples are drawn, but that are
optimized for, corresponding to point estimates of these.
mirror_samples : boolean
Whether the negatives of the drawn samples are also used,
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. (default : False)
Notes
-----
For further details see: Metric Gaussian Variational Inference
(in preparation)
"""
def __init__(self, mean, hamiltonian, n_sampels, constants=[],
point_estimates=None, mirror_samples=False,
_samples=None): _samples=None):
super(KL_Energy, self).__init__(position) super(MetricGaussianKL, self).__init__(mean)
if h.domain is not position.domain: if hamiltonian.domain is not mean.domain:
raise TypeError raise TypeError
self._h = h self._hamiltonian = hamiltonian
self._constants = constants self._constants = constants
if constants_samples is None: if point_estimates is None:
constants_samples = constants point_estimates = constants
self._constants_samples = constants_samples self._constants_samples = point_estimates
if _samples is None: if _samples is None:
met = h(Linearization.make_partial_var( met = hamiltonian(Linearization.make_partial_var(
position, constants_samples, True)).metric mean, point_estimates, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp)) for _ in range(n_sampels))
if gen_mirrored_samples: if mirror_samples:
_samples += tuple(-s for s in _samples) _samples += tuple(-s for s in _samples)
self._samples = _samples self._samples = _samples
self._lin = Linearization.make_partial_var(position, constants) self._lin = Linearization.make_partial_var(mean, constants)
v, g = None, None v, g = None, None
for s in self._samples: for s in self._samples:
tmp = self._h(self._lin+s) tmp = self._hamiltonian(self._lin+s)
if v is None: if v is None:
v = tmp.val.local_data[()] v = tmp.val.local_data[()]
g = tmp.gradient g = tmp.gradient
...@@ -56,9 +95,9 @@ class KL_Energy(Energy): ...@@ -56,9 +95,9 @@ class KL_Energy(Energy):
self._metric = None self._metric = None
def at(self, position): def at(self, position):
return KL_Energy(position, self._h, 0, return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._constants_samples, self._constants, self._constants_samples,
_samples=self._samples) _samples=self._samples)
@property @property
def value(self): def value(self):
...@@ -71,7 +110,8 @@ class KL_Energy(Energy): ...@@ -71,7 +110,8 @@ class KL_Energy(Energy):
def _get_metric(self): def _get_metric(self):
if self._metric is None: if self._metric is None:
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
mymap = map(lambda v: self._h(lin+v).metric, self._samples) mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
self._metric = utilities.my_sum(mymap) self._metric = utilities.my_sum(mymap)
self._metric = self._metric.scale(1./len(self._samples)) self._metric = self._metric.scale(1./len(self._samples))
......
...@@ -27,7 +27,7 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -27,7 +27,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
domain : MultiDomain domain : MultiDomain
Domain and target of the operator. Domain and target of the operator.
operators : dict operators : dict
Dictionary with subdomain names as keys and :class:`LinearOperator`s Dictionary with subdomain names as keys and :class:`LinearOperator` s
as items. as items.
""" """
def __init__(self, domain, operators): def __init__(self, domain, operators):
......
...@@ -259,7 +259,7 @@ class BernoulliEnergy(EnergyOperator): ...@@ -259,7 +259,7 @@ class BernoulliEnergy(EnergyOperator):
return v.add_metric(met) return v.add_metric(met)
class Hamiltonian(EnergyOperator): class StandardHamiltonian(EnergyOperator):
"""Computes an information Hamiltonian in its standard form, i.e. with the """Computes an information Hamiltonian in its standard form, i.e. with the
prior being a Gaussian with unit covariance. prior being a Gaussian with unit covariance.
...@@ -314,13 +314,13 @@ class Hamiltonian(EnergyOperator): ...@@ -314,13 +314,13 @@ class Hamiltonian(EnergyOperator):
def __repr__(self): def __repr__(self):
subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__())) subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
subs += '\nPrior: Quadratic{}'.format(self._lh.domain.keys()) subs += '\nPrior: Quadratic{}'.format(self._lh.domain.keys())
return 'Hamiltonian:\n' + utilities.indent(subs) return 'StandardHamiltonian:\n' + utilities.indent(subs)
class AveragedEnergy(EnergyOperator): class AveragedEnergy(EnergyOperator):
"""Computes Kullback-Leibler (KL) divergence or Gibbs free energies. """Computes Kullback-Leibler (KL) divergence or Gibbs free energies.
A sample-averaged energy, e.g. an Hamiltonian, approximates the relevant A sample-averaged energy, e.g. a Hamiltonian, approximates the relevant
part of a KL to be used in Variational Bayes inference if the samples are part of a KL to be used in Variational Bayes inference if the samples are
drawn from the approximating Gaussian: drawn from the approximating Gaussian:
......
...@@ -69,7 +69,7 @@ def test_hamiltonian_and_KL(field): ...@@ -69,7 +69,7 @@ def test_hamiltonian_and_KL(field):
field = field.exp() field = field.exp()
space = field.domain space = field.domain
lh = ift.GaussianEnergy(domain=space) lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.Hamiltonian(lh) hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_value_gradient_consistency(hamiltonian, field) ift.extra.check_value_gradient_consistency(hamiltonian, field)
S = ift.ScalingOperator(1., space) S = ift.ScalingOperator(1., space)
samps = [S.draw_sample() for i in range(3)] samps = [S.draw_sample() for i in range(3)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment