diff --git a/nifty5/minimization/metric_gaussian_kl_mpi.py b/nifty5/minimization/metric_gaussian_kl_mpi.py
index 5a6b7f864edf7289072439718035a25adcb747c0..4175b202c5d2fcd1a69c336abe7a509d7a399738 100644
--- a/nifty5/minimization/metric_gaussian_kl_mpi.py
+++ b/nifty5/minimization/metric_gaussian_kl_mpi.py
@@ -18,9 +18,12 @@
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
+from ..operators.endomorphic_operator import EndomorphicOperator
from .energy import Energy
from mpi4py import MPI
import numpy as np
+from ..probing import approximation2endo
+from ..sugar import full, makeOp
from ..field import Field
from ..multi_field import MultiField
@@ -56,10 +59,83 @@ def allreduce_sum_field(fld):
return MultiField(fld.domain, res)
+class KLMetric(EndomorphicOperator):
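+    """Proxy operator for the sampled metric of a `MetricGaussianKL_MPI`.
+
+    Application and sample drawing are delegated back to the KL energy,
+    which accounts for the samples being distributed over MPI tasks.
+    Since the metric is self-adjoint, TIMES and ADJOINT_TIMES coincide."""
+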
+ def __init__(self, KL):
+ self._KL = KL
+ self._capability = self.TIMES | self.ADJOINT_TIMES
+ self._domain = KL.position.domain
+
+ def apply(self, x, mode):
+ self._check_input(x, mode)
+ return self._KL.apply_metric(x)
+
+ def draw_sample(self, from_inverse=False, dtype=np.float64):
+        return self._KL.metric_sample(from_inverse, dtype)
+
+
class MetricGaussianKL_MPI(Energy):
+ """Provides the sampled Kullback-Leibler divergence between a distribution
+ and a Metric Gaussian.
+
+ A Metric Gaussian is used to approximate another probability distribution.
+ It is a Gaussian distribution that uses the Fisher information metric of
+ the other distribution at the location of its mean to approximate the
+ variance. In order to infer the mean, a stochastic estimate of the
+ Kullback-Leibler divergence is minimized. This estimate is obtained by
+ sampling the Metric Gaussian at the current mean. During minimization
+ these samples are kept constant; only the mean is updated. Due to the
+ typically nonlinear structure of the true distribution these samples have
+    to be updated eventually by instantiating `MetricGaussianKL` again. For the
+ true probability distribution the standard parametrization is assumed.
+ The samples of this class are distributed among MPI tasks.
+
+ Parameters
+ ----------
+ mean : Field
+ Mean of the Gaussian probability distribution.
+ hamiltonian : StandardHamiltonian
+ Hamiltonian of the approximated probability distribution.
+    n_samples : int
+ Number of samples used to stochastically estimate the KL.
+ constants : list
+ List of parameter keys that are kept constant during optimization.
+ Default is no constants.
+ point_estimates : list
+ List of parameter keys for which no samples are drawn, but that are
+        (possibly) optimized for, corresponding to point estimates of these
+        parameters. Default is to draw samples for the complete domain.
+    mirror_samples : bool
+        Whether the negatives of the drawn samples are also used,
+ as they are equally legitimate samples. If true, the number of used
+ samples doubles. Mirroring samples stabilizes the KL estimate as
+ extreme sample variation is counterbalanced. Default is False.
+ napprox : int
+        Number of samples used to compute the preconditioner for sampling.
+        No preconditioning is done by default.
+ _samples : None
+        Only a parameter for internal use. Typically not set by users.
+ seed_offset : int
+        An offset controlling the seed from which the samples are drawn.
+        By default, the seed differs between MPI tasks but is the same
+        every time this class is initialized.
+
+ Note
+ ----
+    The two lists `constants` and `point_estimates` are independent of each
+ other. It is possible to sample along domains which are kept constant
+ during minimization and vice versa.
+
+ See also
+ --------
+ `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
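+
+    A minimal usage sketch (assuming `ham` is a `StandardHamiltonian`,
+    `mean` a field on its domain, and `minimizer` a NIFTy minimizer such
+    as `NewtonCG`; to be run under MPI)::
+
+        KL = MetricGaussianKL_MPI(mean, ham, n_samples=8,
+                                  mirror_samples=True)
+        KL, convergence = minimizer(KL)
+        mean = KL.position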
+ """
+
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
- _samples=None, seed_offset=0):
+ napprox=0, _samples=None, seed_offset=0):
super(MetricGaussianKL_MPI, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
@@ -82,6 +158,8 @@ class MetricGaussianKL_MPI(Energy):
lo, hi = _shareRange(n_samples, ntask, rank)
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
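+        # Optionally probe the metric with `napprox` random samples and
+        # attach the resulting diagonal approximation, which acts as a
+        # preconditioner when drawing samples.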
+ if napprox > 1:
+ met._approximation = makeOp(approximation2endo(met, napprox))
_samples = []
for i in range(lo, hi):
if mirror_samples:
@@ -142,8 +220,8 @@ class MetricGaussianKL_MPI(Energy):
else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
- self._metric = utilities.my_sum(mymap)
- self._metric = self._metric.scale(1./self._n_samples)
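+        # Keep the unscaled sum of the per-sample metrics; the averaged
+        # metric is this sum rescaled by 1/n_samples.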
+ self.unscaled_metric = utilities.my_sum(mymap)
+ self._metric = self.unscaled_metric.scale(1./self._n_samples)
def apply_metric(self, x):
self._get_metric()
@@ -151,12 +229,22 @@ class MetricGaussianKL_MPI(Energy):
@property
def metric(self):
- if ntask > 1:
- raise ValueError("not supported when MPI is active")
- return self._metric
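+        # Even when samples are distributed over MPI tasks, hand out a
+        # proxy operator that delegates to apply_metric and metric_sample
+        # instead of raising.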
+ return KLMetric(self)
@property
def samples(self):
res = _comm.allgather(self._samples)
res = [item for sublist in res for item in sublist]
return res
+
+ def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
+ if from_inverse:
+ raise NotImplementedError()
+ lin = self._lin.with_want_metric()
+        samp = full(self._hamiltonian.domain, 0.)
+        # Accumulate one metric sample per locally stored residual sample,
+        # then sum the contributions of all MPI tasks.
+        for s in self._samples:
+            met = self._hamiltonian(lin+s).metric
+            samp = samp + met.draw_sample(dtype=dtype)
+ return allreduce_sum_field(samp)
+
+ def metric_sample(self, from_inverse=False, dtype=np.float64):
+ return self.unscaled_metric_sample(from_inverse, dtype)/self._n_samples
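+
+
+# A rough usage sketch for the metric proxy (assuming `KL` is an already
+# constructed MetricGaussianKL_MPI instance; names are illustrative):
+#
+#     met = KL.metric                # KLMetric proxy, nothing reduced yet
+#     y = met(x)                     # applies the averaged sampled metric
+#     s = met.draw_sample()          # metric sample, allreduced over tasks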