diff --git a/nifty5/minimization/metric_gaussian_kl.py b/nifty5/minimization/metric_gaussian_kl.py index b926b74c7b61d64e87622842060d33186bd6a7e4..203bc6e252e28c5136a0ee9e86528c994b1faed4 100644 --- a/nifty5/minimization/metric_gaussian_kl.py +++ b/nifty5/minimization/metric_gaussian_kl.py @@ -21,16 +21,19 @@ from .. import utilities class MetricGaussianKL(Energy): - """Provides the sampled Kullback-Leibler divergence between a distribution and a Metric Gaussian. - - A Metric Gaussian is used to approximate some other distribution. - It is a Gaussian distribution that uses the Fisher Information Metric - of the other distribution at the location of its mean to approximate the variance. - In order to infer the mean, the a stochastic estimate of the Kullback-Leibler divergence - is minimized. This estimate is obtained by drawing samples from the Metric Gaussian at the current mean. - During minimization these samples are kept constant, updating only the mean. Due to the typically nonlinear - structure of the true distribution these samples have to be updated by re-initializing this class at some point. - Here standard parametrization of the true distribution is assumed. + """Provides the sampled Kullback-Leibler divergence between a distribution + and a Metric Gaussian. + + A Metric Gaussian is used to approximate some other distribution. + It is a Gaussian distribution that uses the Fisher Information Metric + of the other distribution at the location of its mean to approximate the + variance. In order to infer the mean, a stochastic estimate of the + Kullback-Leibler divergence is minimized. This estimate is obtained by + drawing samples from the Metric Gaussian at the current mean. + During minimization these samples are kept constant, updating only the + mean. Due to the typically nonlinear structure of the true distribution + these samples have to be updated by re-initializing this class at some + point. 
+ Here standard parametrization of the true distribution is assumed. Parameters ---------- @@ -53,7 +56,8 @@ class MetricGaussianKL(Energy): Notes ----- - For further details see: Metric Gaussian Variational Inference (in preparation) + For further details see: Metric Gaussian Variational Inference + (in preparation) """ def __init__(self, mean, hamiltonian, n_sampels, constants=[], @@ -106,7 +110,8 @@ class MetricGaussianKL(Energy): def _get_metric(self): if self._metric is None: lin = self._lin.with_want_metric() - mymap = map(lambda v: self._hamiltonian(lin+v).metric, self._samples) + mymap = map(lambda v: self._hamiltonian(lin+v).metric, + self._samples) self._metric = utilities.my_sum(mymap) self._metric = self._metric.scale(1./len(self._samples))