diff --git a/nifty5/minimization/metric_gaussian_kl.py b/nifty5/minimization/metric_gaussian_kl.py index c42844a36287fa538eeee00b0e5278527df46b57..74f83a39409da890a888bb73e4c70cda93c80cd3 100644 --- a/nifty5/minimization/metric_gaussian_kl.py +++ b/nifty5/minimization/metric_gaussian_kl.py @@ -18,6 +18,8 @@ from .. import utilities from ..linearization import Linearization from ..operators.energy_operators import StandardHamiltonian +from ..probing import approximation2endo +from ..sugar import makeOp from .energy import Energy @@ -56,6 +58,9 @@ class MetricGaussianKL(Energy): as they are equally legitimate samples. If true, the number of used samples doubles. Mirroring samples stabilizes the KL estimate as extreme sample variation is counterbalanced. Default is False. + napprox : int + Number of samples for computing preconditioner for sampling. No + preconditioning is done by default. _samples : None Only a parameter for internal uses. Typically not to be set by users. @@ -67,12 +72,13 @@ class MetricGaussianKL(Energy): See also -------- - Metric Gaussian Variational Inference (FIXME in preparation) + `Metric Gaussian Variational Inference`, Jakob Knollmüller, + Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_ """ def __init__(self, mean, hamiltonian, n_samples, constants=[], point_estimates=[], mirror_samples=False, - _samples=None): + napprox=0, _samples=None): super(MetricGaussianKL, self).__init__(mean) if not isinstance(hamiltonian, StandardHamiltonian): @@ -91,12 +97,15 @@ class MetricGaussianKL(Energy): if _samples is None: met = hamiltonian(Linearization.make_partial_var( mean, point_estimates, True)).metric + if napprox > 1: + met._approximation = makeOp(approximation2endo(met, napprox)) _samples = tuple(met.draw_sample(from_inverse=True) for _ in range(n_samples)) if mirror_samples: _samples += tuple(-s for s in _samples) self._samples = _samples + # FIXME Use simplify for constant input instead self._lin = Linearization.make_partial_var(mean, constants) v, g = None, None for s in self._samples: @@ -110,11 +119,12 @@ class MetricGaussianKL(Energy): self._val = v / len(self._samples) self._grad = g * (1./len(self._samples)) self._metric = None + self._napprox = napprox def at(self, position): return MetricGaussianKL(position, self._hamiltonian, 0, self._constants, self._point_estimates, - _samples=self._samples) + napprox=self._napprox, _samples=self._samples) @property def value(self): @@ -129,8 +139,12 @@ class MetricGaussianKL(Energy): lin = self._lin.with_want_metric() mymap = map(lambda v: self._hamiltonian(lin+v).metric, self._samples) - self._metric = utilities.my_sum(mymap) - self._metric = self._metric.scale(1./len(self._samples)) + self._unscaled_metric = utilities.my_sum(mymap) + self._metric = self._unscaled_metric.scale(1./len(self._samples)) + + def unscaled_metric(self): + self._get_metric() + return self._unscaled_metric, 1/len(self._samples) def apply_metric(self, x): self._get_metric() diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 80905fbfb395930367f8af57d042939b692340ec..9fd8efab3fbc08f1295bfa80812ee2b949b651a0 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -326,7 +326,7 @@ class NullOperator(LinearOperator): return self._nullfield(self._tgt(mode)) -class _PartialExtractor(LinearOperator): +class PartialExtractor(LinearOperator): def __init__(self, domain, target): if not isinstance(domain, MultiDomain): raise TypeError("MultiDomain expected") @@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator): self._domain = domain self._target = target for key in self._target.keys(): - if not (self._domain[key] is not self._target[key]): + if self._domain[key] is not self._target[key]: raise ValueError("domain mismatch") self._capability = self.TIMES | self.ADJOINT_TIMES diff --git a/nifty5/probing.py b/nifty5/probing.py index e5c06392258ba5286868ca63217f7789bdea9d9e..2c1ad8a3fae03ce3f6cdd4970819cbfdc5385049 100644 --- a/nifty5/probing.py +++ b/nifty5/probing.py @@ -15,9 +15,10 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from .multi_field import MultiField from .operators.endomorphic_operator import EndomorphicOperator from .operators.operator import Operator -from .sugar import from_random +from .sugar import from_global_data, from_random class StatCalculator(object): @@ -134,3 +135,17 @@ def probe_diagonal(op, nprobes, random_type="pm1"): x = from_random(random_type, op.domain) sc.add(op(x).conjugate()*x) return sc.mean + + +def approximation2endo(op, nsamples): + print('Calculate preconditioner') + sc = StatCalculator() + for _ in range(nsamples): + sc.add(op.draw_sample()) + approx = sc.var + dct = approx.to_dict() + for kk in dct: + foo = dct[kk].to_global_data_rw() + foo[foo == 0] = 1 + dct[kk] = from_global_data(dct[kk].domain, foo) + return MultiField.from_dict(dct) diff --git a/test/test_kl.py b/test/test_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..79428c5b5c0a176109687e6ba01a06a562f72f38 --- /dev/null +++ b/test/test_kl.py @@ -0,0 +1,82 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import numpy as np + +import nifty5 as ift +from numpy.testing import assert_, assert_allclose +import pytest + +pmp = pytest.mark.parametrize + + +@pmp('constants', ([], ['a'], ['b'], ['a', 'b'])) +@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b'])) +@pmp('mirror_samples', (True, False)) +def test_kl(constants, point_estimates, mirror_samples): + np.random.seed(42) + dom = ift.RGSpace((12,), (2.12)) + op0 = ift.HarmonicSmoothingOperator(dom, 3) + op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b')) + lh = ift.GaussianEnergy(domain=op.target) @ op + ic = ift.GradientNormController(iteration_limit=5) + h = ift.StandardHamiltonian(lh, ic_samp=ic) + mean0 = ift.from_random('normal', h.domain) + + nsamps = 2 + kl = ift.MetricGaussianKL(mean0, + h, + nsamps, + constants=constants, + point_estimates=point_estimates, + mirror_samples=mirror_samples, + napprox=0) + klpure = ift.MetricGaussianKL(mean0, + h, + nsamps, + mirror_samples=mirror_samples, + napprox=0, + _samples=kl.samples) + + # Test value + assert_allclose(kl.value, klpure.value) + + # Test gradient + for kk in h.domain.keys(): + res0 = klpure.gradient.to_global_data()[kk] + if kk in constants: + res0 = 0*res0 + res1 = kl.gradient.to_global_data()[kk] + assert_allclose(res0, res1) + + # Test number of samples + expected_nsamps = 2*nsamps if mirror_samples else nsamps + assert_(len(kl.samples) == expected_nsamps) + + # Test point_estimates (after drawing samples) + for kk in point_estimates: + for ss in kl.samples: + ss = ss.to_global_data()[kk] + assert_allclose(ss, 0*ss) + + # Test constants (after some minimization) + cg = ift.GradientNormController(iteration_limit=5) + minimizer = ift.NewtonCG(cg) + kl, _ = minimizer(kl) + diff = (mean0 - kl.position).to_global_data() + for kk in constants: + assert_allclose(diff[kk], 0*diff[kk])