Commit 2005e30f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more merges from operator_spectra

parent d0c8860d
Pipeline #61053 passed with stages
in 8 minutes and 58 seconds
......@@ -18,6 +18,8 @@
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeOp
from .energy import Energy
......@@ -56,6 +58,9 @@ class MetricGaussianKL(Energy):
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
......@@ -67,12 +72,13 @@ class MetricGaussianKL(Energy):
See also
Metric Gaussian Variational Inference (FIXME in preparation)
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<>`_
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, _samples=None):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -91,12 +97,15 @@ class MetricGaussianKL(Energy):
if _samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(n_samples))
if mirror_samples:
_samples += tuple(-s for s in _samples)
self._samples = _samples
# FIXME Use simplify for constant input instead
self._lin = Linearization.make_partial_var(mean, constants)
v, g = None, None
for s in self._samples:
......@@ -110,11 +119,12 @@ class MetricGaussianKL(Energy):
self._val = v / len(self._samples)
self._grad = g * (1./len(self._samples))
self._metric = None
self._napprox = napprox
def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._point_estimates,
napprox=self._napprox, _samples=self._samples)
def value(self):
......@@ -129,8 +139,12 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric()
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._metric = utilities.my_sum(mymap)
self._metric = self._metric.scale(1./len(self._samples))
self._unscaled_metric = utilities.my_sum(mymap)
self._metric = self._unscaled_metric.scale(1./len(self._samples))
def unscaled_metric(self):
return self._unscaled_metric, 1/len(self._samples)
def apply_metric(self, x):
......@@ -326,7 +326,7 @@ class NullOperator(LinearOperator):
return self._nullfield(self._tgt(mode))
class _PartialExtractor(LinearOperator):
class PartialExtractor(LinearOperator):
def __init__(self, domain, target):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
......@@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator):
self._domain = domain
self._target = target
for key in self._target.keys():
if not (self._domain[key] is not self._target[key]):
if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES
......@@ -15,9 +15,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .multi_field import MultiField
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator
from .sugar import from_random
from .sugar import from_global_data, from_random
class StatCalculator(object):
......@@ -134,3 +135,17 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
x = from_random(random_type, op.domain)
return sc.mean
def approximation2endo(op, nsamples):
print('Calculate preconditioner')
sc = StatCalculator()
for _ in range(nsamples):
approx = sc.var
dct = approx.to_dict()
for kk in dct:
foo = dct[kk].to_global_data_rw()
foo[foo == 0] = 1
dct[kk] = from_global_data(dct[kk].domain, foo)
return MultiField.from_dict(dct)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2019 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty5 as ift
from numpy.testing import assert_, assert_allclose
import pytest
pmp = pytest.mark.parametrize
@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
lh = ift.GaussianEnergy( @ op
ic = ift.GradientNormController(iteration_limit=5)
h = ift.StandardHamiltonian(lh, ic_samp=ic)
mean0 = ift.from_random('normal', h.domain)
nsamps = 2
kl = ift.MetricGaussianKL(mean0,
klpure = ift.MetricGaussianKL(mean0,
# Test value
assert_allclose(kl.value, klpure.value)
# Test gradient
for kk in h.domain.keys():
res0 = klpure.gradient.to_global_data()[kk]
if kk in constants:
res0 = 0*res0
res1 = kl.gradient.to_global_data()[kk]
assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(kl.samples) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl.samples:
ss = ss.to_global_data()[kk]
assert_allclose(ss, 0*ss)
# Test constants (after some minimization)
cg = ift.GradientNormController(iteration_limit=5)
minimizer = ift.NewtonCG(cg)
kl, _ = minimizer(kl)
diff = (mean0 - kl.position).to_global_data()
for kk in constants:
assert_allclose(diff[kk], 0*diff[kk])
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment