diff --git a/nifty5/minimization/metric_gaussian_kl_mpi.py b/nifty5/minimization/metric_gaussian_kl_mpi.py
index 5a6b7f864edf7289072439718035a25adcb747c0..4175b202c5d2fcd1a69c336abe7a509d7a399738 100644
--- a/nifty5/minimization/metric_gaussian_kl_mpi.py
+++ b/nifty5/minimization/metric_gaussian_kl_mpi.py
@@ -18,9 +18,12 @@
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
+from ..operators.endomorphic_operator import EndomorphicOperator
from .energy import Energy
from mpi4py import MPI
import numpy as np
+from ..probing import approximation2endo
+from ..sugar import full, makeOp
from ..field import Field
from ..multi_field import MultiField
@@ -56,10 +59,83 @@ def allreduce_sum_field(fld):
return MultiField(fld.domain, res)
+class KLMetric(EndomorphicOperator):
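+    """Proxy operator for the sampled metric of a `MetricGaussianKL_MPI`.
+
+    Application and sample drawing are delegated back to the KL energy,
+    which accounts for the samples being distributed over MPI tasks.
+    Since the metric is self-adjoint, TIMES and ADJOINT_TIMES coincide."""
+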
+ def __init__(self, KL):
+ self._KL = KL
+ self._capability = self.TIMES | self.ADJOINT_TIMES
+ self._domain = KL.position.domain
+
+ def apply(self, x, mode):
+ self._check_input(x, mode)
+ return self._KL.apply_metric(x)
+
+ def draw_sample(self, from_inverse=False, dtype=np.float64):
+        return self._KL.metric_sample(from_inverse, dtype)
+
+
class MetricGaussianKL_MPI(Energy):
+ """Provides the sampled Kullback-Leibler divergence between a distribution
+ and a Metric Gaussian.
+
+ A Metric Gaussian is used to approximate another probability distribution.
+ It is a Gaussian distribution that uses the Fisher information metric of
+ the other distribution at the location of its mean to approximate the
+ variance. In order to infer the mean, a stochastic estimate of the
+ Kullback-Leibler divergence is minimized. This estimate is obtained by
+ sampling the Metric Gaussian at the current mean. During minimization
+ these samples are kept constant; only the mean is updated. Due to the
+ typically nonlinear structure of the true distribution these samples have
+    to be updated eventually by instantiating `MetricGaussianKL` again. For the
+ true probability distribution the standard parametrization is assumed.
+ The samples of this class are distributed among MPI tasks.
+
+ Parameters
+ ----------
+ mean : Field
+ Mean of the Gaussian probability distribution.
+ hamiltonian : StandardHamiltonian
+ Hamiltonian of the approximated probability distribution.
+    n_samples : int
+ Number of samples used to stochastically estimate the KL.
+ constants : list
+ List of parameter keys that are kept constant during optimization.
+ Default is no constants.
+ point_estimates : list
+ List of parameter keys for which no samples are drawn, but that are
+        (possibly) optimized for, corresponding to point estimates of these
+        parameters. Default is to draw samples for the complete domain.
+    mirror_samples : bool
+        Whether the negatives of the drawn samples are also used,
+ as they are equally legitimate samples. If true, the number of used
+ samples doubles. Mirroring samples stabilizes the KL estimate as
+ extreme sample variation is counterbalanced. Default is False.
+ napprox : int
+        Number of samples used to compute the preconditioner for sampling.
+        No preconditioning is done by default.
+ _samples : None
+        Only a parameter for internal use. Typically not set by users.
+ seed_offset : int
+        An offset controlling the seed from which the samples are drawn.
+        By default, the seed differs between MPI tasks but is the same
+        every time this class is initialized.
+
+ Note
+ ----
+    The two lists `constants` and `point_estimates` are independent of each
+ other. It is possible to sample along domains which are kept constant
+ during minimization and vice versa.
+
+ See also
+ --------
+ `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
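+
+    A minimal usage sketch (assuming `ham` is a `StandardHamiltonian`,
+    `mean` a field on its domain, and `minimizer` a NIFTy minimizer such
+    as `NewtonCG`; to be run under MPI)::
+
+        KL = MetricGaussianKL_MPI(mean, ham, n_samples=8,
+                                  mirror_samples=True)
+        KL, convergence = minimizer(KL)
+        mean = KL.position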
+ """
+
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
- _samples=None, seed_offset=0):
+ napprox=0, _samples=None, seed_offset=0):
super(MetricGaussianKL_MPI, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
@@ -82,6 +158,8 @@ class MetricGaussianKL_MPI(Energy):
lo, hi = _shareRange(n_samples, ntask, rank)
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
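+        # Optionally probe the metric with `napprox` random samples and
+        # attach the resulting diagonal approximation, which acts as a
+        # preconditioner when drawing samples.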
+ if napprox > 1:
+ met._approximation = makeOp(approximation2endo(met, napprox))
_samples = []
for i in range(lo, hi):
if mirror_samples:
@@ -142,8 +220,8 @@ class MetricGaussianKL_MPI(Energy):
else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
- self._metric = utilities.my_sum(mymap)
- self._metric = self._metric.scale(1./self._n_samples)
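+        # Keep the unscaled sum of the per-sample metrics; the averaged
+        # metric is this sum rescaled by 1/n_samples.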
+ self.unscaled_metric = utilities.my_sum(mymap)
+ self._metric = self.unscaled_metric.scale(1./self._n_samples)
def apply_metric(self, x):
self._get_metric()
@@ -151,12 +229,22 @@ class MetricGaussianKL_MPI(Energy):
@property
def metric(self):
- if ntask > 1:
- raise ValueError("not supported when MPI is active")
- return self._metric
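+        # Even when samples are distributed over MPI tasks, hand out a
+        # proxy operator that delegates to apply_metric and metric_sample
+        # instead of raising.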
+ return KLMetric(self)
@property
def samples(self):
res = _comm.allgather(self._samples)
res = [item for sublist in res for item in sublist]
return res
+
+ def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
+ if from_inverse:
+ raise NotImplementedError()
+ lin = self._lin.with_want_metric()
+        samp = full(self._hamiltonian.domain, 0.)
+        # Accumulate one metric sample per locally stored residual sample,
+        # then sum the contributions of all MPI tasks.
+        for s in self._samples:
+            met = self._hamiltonian(lin+s).metric
+            samp = samp + met.draw_sample(dtype=dtype)
+ return allreduce_sum_field(samp)
+
+ def metric_sample(self, from_inverse=False, dtype=np.float64):
+ return self.unscaled_metric_sample(from_inverse, dtype)/self._n_samples
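+
+
+# A rough usage sketch for the metric proxy (assuming `KL` is an already
+# constructed MetricGaussianKL_MPI instance; names are illustrative):
+#
+#     met = KL.metric                # KLMetric proxy, nothing reduced yet
+#     y = met(x)                     # applies the averaged sampled metric
+#     s = met.draw_sample()          # metric sample, allreduced over tasks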