Commit 2005e30f authored by Martin Reinecke's avatar Martin Reinecke

more merges from operator_spectra

parent d0c8860d
Pipeline #61053 passed with stages
in 8 minutes and 58 seconds
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
from .. import utilities from .. import utilities
from ..linearization import Linearization from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeOp
from .energy import Energy from .energy import Energy
...@@ -56,6 +58,9 @@ class MetricGaussianKL(Energy): ...@@ -56,6 +58,9 @@ class MetricGaussianKL(Energy):
as they are equally legitimate samples. If true, the number of used as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False. extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
_samples : None _samples : None
Only a parameter for internal uses. Typically not to be set by users. Only a parameter for internal uses. Typically not to be set by users.
...@@ -67,12 +72,13 @@ class MetricGaussianKL(Energy): ...@@ -67,12 +72,13 @@ class MetricGaussianKL(Energy):
See also See also
-------- --------
Metric Gaussian Variational Inference (FIXME in preparation) `Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
""" """
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
_samples=None): napprox=0, _samples=None):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian): if not isinstance(hamiltonian, StandardHamiltonian):
...@@ -91,12 +97,15 @@ class MetricGaussianKL(Energy): ...@@ -91,12 +97,15 @@ class MetricGaussianKL(Energy):
if _samples is None: if _samples is None:
met = hamiltonian(Linearization.make_partial_var( met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric mean, point_estimates, True)).metric
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(n_samples)) for _ in range(n_samples))
if mirror_samples: if mirror_samples:
_samples += tuple(-s for s in _samples) _samples += tuple(-s for s in _samples)
self._samples = _samples self._samples = _samples
# FIXME Use simplify for constant input instead
self._lin = Linearization.make_partial_var(mean, constants) self._lin = Linearization.make_partial_var(mean, constants)
v, g = None, None v, g = None, None
for s in self._samples: for s in self._samples:
...@@ -110,11 +119,12 @@ class MetricGaussianKL(Energy): ...@@ -110,11 +119,12 @@ class MetricGaussianKL(Energy):
self._val = v / len(self._samples) self._val = v / len(self._samples)
self._grad = g * (1./len(self._samples)) self._grad = g * (1./len(self._samples))
self._metric = None self._metric = None
self._napprox = napprox
def at(self, position): def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0, return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._point_estimates, self._constants, self._point_estimates,
_samples=self._samples) napprox=self._napprox, _samples=self._samples)
@property @property
def value(self): def value(self):
...@@ -129,8 +139,12 @@ class MetricGaussianKL(Energy): ...@@ -129,8 +139,12 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
mymap = map(lambda v: self._hamiltonian(lin+v).metric, mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples) self._samples)
self._metric = utilities.my_sum(mymap) self._unscaled_metric = utilities.my_sum(mymap)
self._metric = self._metric.scale(1./len(self._samples)) self._metric = self._unscaled_metric.scale(1./len(self._samples))
def unscaled_metric(self):
self._get_metric()
return self._unscaled_metric, 1/len(self._samples)
def apply_metric(self, x): def apply_metric(self, x):
self._get_metric() self._get_metric()
......
...@@ -326,7 +326,7 @@ class NullOperator(LinearOperator): ...@@ -326,7 +326,7 @@ class NullOperator(LinearOperator):
return self._nullfield(self._tgt(mode)) return self._nullfield(self._tgt(mode))
class _PartialExtractor(LinearOperator): class PartialExtractor(LinearOperator):
def __init__(self, domain, target): def __init__(self, domain, target):
if not isinstance(domain, MultiDomain): if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected") raise TypeError("MultiDomain expected")
...@@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator): ...@@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator):
self._domain = domain self._domain = domain
self._target = target self._target = target
for key in self._target.keys(): for key in self._target.keys():
if not (self._domain[key] is not self._target[key]): if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
......
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .multi_field import MultiField
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator from .operators.operator import Operator
from .sugar import from_random from .sugar import from_global_data, from_random
class StatCalculator(object): class StatCalculator(object):
...@@ -134,3 +135,17 @@ def probe_diagonal(op, nprobes, random_type="pm1"): ...@@ -134,3 +135,17 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
x = from_random(random_type, op.domain) x = from_random(random_type, op.domain)
sc.add(op(x).conjugate()*x) sc.add(op(x).conjugate()*x)
return sc.mean return sc.mean
def approximation2endo(op, nsamples):
    """Estimate a diagonal (endomorphic) approximation of `op`.

    Draws `nsamples` samples from `op`, takes their per-pixel variance as a
    diagonal stand-in for the operator, and returns it as a MultiField
    suitable for use as a sampling preconditioner.

    Parameters
    ----------
    op : operator exposing `draw_sample()`
        The operator to approximate.
    nsamples : int
        Number of samples used for the variance estimate.

    Returns
    -------
    MultiField
        Per-domain variance fields with zeros replaced by ones.
    """
    print('Calculate preconditioner')
    stats = StatCalculator()
    for _ in range(nsamples):
        stats.add(op.draw_sample())
    variance = stats.var.to_dict()
    for key, field in variance.items():
        values = field.to_global_data_rw()
        # A zero variance entry would make the diagonal non-invertible;
        # replace it with 1 (the neutral scaling).
        values[values == 0] = 1
        variance[key] = from_global_data(field.domain, values)
    return MultiField.from_dict(variance)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty5 as ift
from numpy.testing import assert_, assert_allclose
import pytest
pmp = pytest.mark.parametrize
@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
    """Consistency checks for MetricGaussianKL.

    Compares a KL energy built with `constants`/`point_estimates` against a
    plain one sharing the same samples, then verifies gradient masking,
    sample counts, point-estimate samples, and that constant degrees of
    freedom stay fixed under minimization.
    """
    # Fixed seed so the drawn samples (and hence the KL value) are reproducible.
    np.random.seed(42)
    dom = ift.RGSpace((12,), (2.12))
    op0 = ift.HarmonicSmoothingOperator(dom, 3)
    # Two-key model: 'a' multiplies the smoothed 'b' field, giving a
    # MultiDomain with keys {'a', 'b'} that the parametrized key lists refer to.
    op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
    lh = ift.GaussianEnergy(domain=op.target) @ op
    ic = ift.GradientNormController(iteration_limit=5)
    h = ift.StandardHamiltonian(lh, ic_samp=ic)
    mean0 = ift.from_random('normal', h.domain)
    nsamps = 2
    kl = ift.MetricGaussianKL(mean0,
                              h,
                              nsamps,
                              constants=constants,
                              point_estimates=point_estimates,
                              mirror_samples=mirror_samples,
                              napprox=0)
    # Reference KL without constants/point_estimates, reusing kl's samples so
    # both energies are evaluated on identical sample sets.
    klpure = ift.MetricGaussianKL(mean0,
                                  h,
                                  nsamps,
                                  mirror_samples=mirror_samples,
                                  napprox=0,
                                  _samples=kl.samples)
    # Test value
    assert_allclose(kl.value, klpure.value)
    # Test gradient: components held constant must have zero gradient;
    # all others must match the unconstrained KL.
    for kk in h.domain.keys():
        res0 = klpure.gradient.to_global_data()[kk]
        if kk in constants:
            res0 = 0*res0
        res1 = kl.gradient.to_global_data()[kk]
        assert_allclose(res0, res1)
    # Test number of samples (mirroring doubles the sample count)
    expected_nsamps = 2*nsamps if mirror_samples else nsamps
    assert_(len(kl.samples) == expected_nsamps)
    # Test point_estimates (after drawing samples): point-estimated keys
    # are expected to carry zero sample perturbation.
    for kk in point_estimates:
        for ss in kl.samples:
            ss = ss.to_global_data()[kk]
            assert_allclose(ss, 0*ss)
    # Test constants (after some minimization): constant keys must not
    # move away from the initial position mean0.
    cg = ift.GradientNormController(iteration_limit=5)
    minimizer = ift.NewtonCG(cg)
    kl, _ = minimizer(kl)
    diff = (mean0 - kl.position).to_global_data()
    for kk in constants:
        assert_allclose(diff[kk], 0*diff[kk])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment