Commit ffc6059b authored by Martin Reinecke

Merge branch 'spectra_partial_merge' into 'NIFTy_5'

more merges from operator_spectra

See merge request !348
parents 6c109e41 784f49a3
Pipeline #61424 passed with stages
in 22 minutes and 56 seconds
@@ -18,8 +18,14 @@
 import numpy as np

 from ..logger import logger
+from ..probing import approximation2endo
+from ..sugar import makeOp
+from .conjugate_gradient import ConjugateGradient
+from .iteration_controllers import (AbsDeltaEnergyController,
+                                    GradientNormController)
 from .line_search import LineSearch
 from .minimizer import Minimizer
+from .quadratic_energy import QuadraticEnergy


 class DescentMinimizer(Minimizer):
@@ -79,7 +85,8 @@ class DescentMinimizer(Minimizer):

             # compute a step length that reduces energy.value sufficiently
             new_energy, success = self.line_searcher.perform_line_search(
-                energy=energy, pk=self.get_descent_direction(energy),
+                energy=energy,
+                pk=self.get_descent_direction(energy, f_k_minus_1),
                 f_k_minus_1=f_k_minus_1)
             if not success:
                 self.reset()
@@ -103,7 +110,7 @@ class DescentMinimizer(Minimizer):
     def reset(self):
         pass

-    def get_descent_direction(self, energy):
+    def get_descent_direction(self, energy, old_value=None):
         """Calculates the next descent direction.

         Parameters
@@ -112,6 +119,10 @@ class DescentMinimizer(Minimizer):
            An instance of the Energy class which shall be minimized. The
            position of `energy` is used as the starting point of minimization.
+        old_value : float
+            if provided, this must be the value of the energy in the previous
+            step.

         Returns
         -------
         Field
@@ -127,7 +138,7 @@ class SteepestDescent(DescentMinimizer):
     functional's gradient for minimization.
     """

-    def get_descent_direction(self, energy):
+    def get_descent_direction(self, energy, _=None):
         return -energy.gradient
@@ -144,7 +155,7 @@ class RelaxedNewton(DescentMinimizer):
         super(RelaxedNewton, self).__init__(controller=controller,
                                             line_searcher=line_searcher)

-    def get_descent_direction(self, energy):
+    def get_descent_direction(self, energy, _=None):
         return -energy.metric.inverse_times(energy.gradient)
@@ -154,44 +165,32 @@ class NewtonCG(DescentMinimizer):
     Algorithm derived from SciPy sources.
     """

-    def __init__(self, controller, line_searcher=None):
+    def __init__(self, controller, napprox=0, line_searcher=None, name=None,
+                 nreset=20, max_cg_iterations=200, energy_reduction_factor=0.1):
         if line_searcher is None:
             line_searcher = LineSearch(preferred_initial_step_size=1.)
         super(NewtonCG, self).__init__(controller=controller,
                                        line_searcher=line_searcher)
+        self._napprox = napprox
+        self._name = name
+        self._nreset = nreset
+        self._max_cg_iterations = max_cg_iterations
+        self._alpha = energy_reduction_factor

-    def get_descent_direction(self, energy):
-        float64eps = np.finfo(np.float64).eps
-        grad = energy.gradient
-        maggrad = abs(grad).sum()
-        termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
-        xsupi = energy.position*0
-        ri = grad
-        psupi = -ri
-        dri0 = ri.vdot(ri)
-
-        i = 0
-        while True:
-            if abs(ri).sum() <= termcond:
-                return xsupi
-            Ap = energy.apply_metric(psupi)
-            # check curvature
-            curv = psupi.vdot(Ap)
-            if 0 <= curv <= 3*float64eps:
-                return xsupi
-            elif curv < 0:
-                return xsupi if i > 0 else (dri0/curv) * grad
-            alphai = dri0/curv
-            xsupi = xsupi + alphai*psupi
-            ri = ri + alphai*Ap
-            dri1 = ri.vdot(ri)
-            psupi = (dri1/dri0)*psupi - ri
-            i += 1
-            dri0 = dri1  # update numpy.dot(ri,ri) for next time.
-        # curvature keeps increasing, bail out
-        raise ValueError("Warning: CG iterations didn't converge. "
-                         "The Hessian is not positive definite.")
+    def get_descent_direction(self, energy, old_value=None):
+        if old_value is None:
+            ic = GradientNormController(iteration_limit=5)
+        else:
+            ediff = self._alpha*(old_value-energy.value)
+            ic = AbsDeltaEnergyController(
+                ediff, iteration_limit=self._max_cg_iterations, name=self._name)
+        e = QuadraticEnergy(0*energy.position, energy.metric, energy.gradient)
+        p = None
+        if self._napprox > 1:
+            unscmet, sc = energy.unscaled_metric()
+            p = makeOp(approximation2endo(unscmet, self._napprox)*sc).inverse
+        e, conv = ConjugateGradient(ic, nreset=self._nreset)(e, p)
+        return -e.position


 class L_BFGS(DescentMinimizer):
@@ -210,7 +209,7 @@ class L_BFGS(DescentMinimizer):
         self._s = [None]*self.max_history_length
         self._y = [None]*self.max_history_length

-    def get_descent_direction(self, energy):
+    def get_descent_direction(self, energy, _=None):
         x = energy.position
         s = self._s
         y = self._y
@@ -273,7 +272,7 @@ class VL_BFGS(DescentMinimizer):
     def reset(self):
         self._information_store = None

-    def get_descent_direction(self, energy):
+    def get_descent_direction(self, energy, _=None):
         x = energy.position
         gradient = energy.gradient
         # initialize the information store if it doesn't already exist
...
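Not part of the diff above, but for orientation: a minimal usage sketch of the reworked NewtonCG, loosely modeled on the new test file at the bottom of this merge request. The operator setup and all controller and parameter values are illustrative only.

import numpy as np
import nifty5 as ift

np.random.seed(42)
dom = ift.RGSpace((12,), (2.12,))
op = ift.HarmonicSmoothingOperator(dom, 3).ducktape('b')
lh = ift.GaussianEnergy(domain=op.target) @ op
ham = ift.StandardHamiltonian(
    lh, ic_samp=ift.GradientNormController(iteration_limit=5))
kl = ift.MetricGaussianKL(ift.from_random('normal', ham.domain), ham, 2)

# On the first Newton step old_value is None, so the inner CG runs with a
# fixed 5-iteration GradientNormController; on later steps its tolerance is
# energy_reduction_factor times the energy decrease of the previous Newton
# step (via AbsDeltaEnergyController).  napprox is left at its default here;
# napprox > 1 would additionally precondition the CG with a diagonal built
# from metric samples, which requires the energy to provide
# unscaled_metric(), as MetricGaussianKL now does.
minimizer = ift.NewtonCG(
    ift.GradientNormController(name='Newton', iteration_limit=3),
    name='NCG', nreset=20, max_cg_iterations=200,
    energy_reduction_factor=0.1)
kl, convergence = minimizer(kl)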
@@ -18,6 +18,8 @@
 from .. import utilities
 from ..linearization import Linearization
 from ..operators.energy_operators import StandardHamiltonian
+from ..probing import approximation2endo
+from ..sugar import makeOp
 from .energy import Energy
@@ -56,6 +58,9 @@ class MetricGaussianKL(Energy):
        as they are equally legitimate samples. If true, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
+    napprox : int
+        Number of samples for computing preconditioner for sampling. No
+        preconditioning is done by default.
     _samples : None
        Only a parameter for internal uses. Typically not to be set by users.
@@ -67,12 +72,13 @@ class MetricGaussianKL(Energy):
     See also
     --------
-    Metric Gaussian Variational Inference (FIXME in preparation)
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
     """

     def __init__(self, mean, hamiltonian, n_samples, constants=[],
                  point_estimates=[], mirror_samples=False,
-                 _samples=None):
+                 napprox=0, _samples=None):
         super(MetricGaussianKL, self).__init__(mean)

         if not isinstance(hamiltonian, StandardHamiltonian):
@@ -91,12 +97,15 @@ class MetricGaussianKL(Energy):
         if _samples is None:
             met = hamiltonian(Linearization.make_partial_var(
                 mean, point_estimates, True)).metric
+            if napprox > 1:
+                met._approximation = makeOp(approximation2endo(met, napprox))
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(n_samples))
             if mirror_samples:
                 _samples += tuple(-s for s in _samples)
         self._samples = _samples

+        # FIXME Use simplify for constant input instead
         self._lin = Linearization.make_partial_var(mean, constants)
         v, g = None, None
         for s in self._samples:
@@ -110,11 +119,12 @@ class MetricGaussianKL(Energy):
         self._val = v / len(self._samples)
         self._grad = g * (1./len(self._samples))
         self._metric = None
+        self._napprox = napprox

     def at(self, position):
         return MetricGaussianKL(position, self._hamiltonian, 0,
                                 self._constants, self._point_estimates,
-                                _samples=self._samples)
+                                napprox=self._napprox, _samples=self._samples)

     @property
     def value(self):
@@ -129,8 +139,12 @@ class MetricGaussianKL(Energy):
             lin = self._lin.with_want_metric()
             mymap = map(lambda v: self._hamiltonian(lin+v).metric,
                         self._samples)
-            self._metric = utilities.my_sum(mymap)
-            self._metric = self._metric.scale(1./len(self._samples))
+            self._unscaled_metric = utilities.my_sum(mymap)
+            self._metric = self._unscaled_metric.scale(1./len(self._samples))
+
+    def unscaled_metric(self):
+        self._get_metric()
+        return self._unscaled_metric, 1/len(self._samples)

     def apply_metric(self, x):
         self._get_metric()
...
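Again not part of the diff, for orientation: a small sketch of how the new napprox argument and the unscaled_metric() accessor might be used; the likelihood setup below is illustrative only.

import numpy as np
import nifty5 as ift

np.random.seed(42)
dom = ift.RGSpace((12,), (2.12,))
op = ift.HarmonicSmoothingOperator(dom, 3).ducktape('b')
lh = ift.GaussianEnergy(domain=op.target) @ op
ham = ift.StandardHamiltonian(
    lh, ic_samp=ift.GradientNormController(iteration_limit=5))
mean = ift.from_random('normal', ham.domain)

# napprox > 1 builds a diagonal approximation of the metric from that many
# samples and attaches it as a preconditioner before the inverse-metric
# samples are drawn; napprox=0 (the default) keeps the previous behaviour.
kl = ift.MetricGaussianKL(mean, ham, n_samples=2, napprox=10)

# unscaled_metric() returns the plain sum of the per-sample metrics together
# with the factor 1/len(samples) that rescales it to the averaged metric.
unscaled, scale = kl.unscaled_metric()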
@@ -326,7 +326,7 @@ class NullOperator(LinearOperator):
         return self._nullfield(self._tgt(mode))


-class _PartialExtractor(LinearOperator):
+class PartialExtractor(LinearOperator):
     def __init__(self, domain, target):
         if not isinstance(domain, MultiDomain):
             raise TypeError("MultiDomain expected")
@@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator):
         self._domain = domain
         self._target = target
         for key in self._target.keys():
-            if not (self._domain[key] is not self._target[key]):
+            if self._domain[key] is not self._target[key]:
                 raise ValueError("domain mismatch")
         self._capability = self.TIMES | self.ADJOINT_TIMES
...
@@ -15,9 +15,10 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

+from .multi_field import MultiField
 from .operators.endomorphic_operator import EndomorphicOperator
 from .operators.operator import Operator
-from .sugar import from_random
+from .sugar import from_global_data, from_random


 class StatCalculator(object):
@@ -134,3 +135,16 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
         x = from_random(random_type, op.domain)
         sc.add(op(x).conjugate()*x)
     return sc.mean
+
+
+def approximation2endo(op, nsamples):
+    sc = StatCalculator()
+    for _ in range(nsamples):
+        sc.add(op.draw_sample())
+    approx = sc.var
+    dct = approx.to_dict()
+    for kk in dct:
+        foo = dct[kk].to_global_data_rw()
+        foo[foo == 0] = 1
+        dct[kk] = from_global_data(dct[kk].domain, foo)
+    return MultiField.from_dict(dct)
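For orientation (not part of the patch): approximation2endo takes any operator over a MultiDomain that can draw samples, forms the per-pixel sample variance, replaces exact zeros by 1 and returns the result as a MultiField, so it can be wrapped with makeOp and inverted as a diagonal preconditioner. A hypothetical sketch, assuming the function is imported from nifty5.probing and using an illustrative Hamiltonian metric:

import numpy as np
import nifty5 as ift
from nifty5.probing import approximation2endo

np.random.seed(42)
dom = ift.RGSpace((12,), (2.12,))
op = ift.HarmonicSmoothingOperator(dom, 3).ducktape('b')
lh = ift.GaussianEnergy(domain=op.target) @ op
ham = ift.StandardHamiltonian(
    lh, ic_samp=ift.GradientNormController(iteration_limit=5))
pos = ift.from_random('normal', ham.domain)
met = ham(ift.Linearization.make_var(pos, want_metric=True)).metric

# Per-pixel sample variance of 10 metric samples, zeros replaced by 1,
# returned as a MultiField over the metric's domain.
diag = approximation2endo(met, 10)

# Wrapped as a diagonal operator and inverted, this is the kind of
# preconditioner the new NewtonCG and MetricGaussianKL code paths build.
precond = ift.makeOp(diag).inverse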
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import nifty5 as ift
from numpy.testing import assert_, assert_allclose
import pytest

pmp = pytest.mark.parametrize


@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
    np.random.seed(42)
    dom = ift.RGSpace((12,), (2.12))
    op0 = ift.HarmonicSmoothingOperator(dom, 3)
    op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
    lh = ift.GaussianEnergy(domain=op.target) @ op
    ic = ift.GradientNormController(iteration_limit=5)
    h = ift.StandardHamiltonian(lh, ic_samp=ic)
    mean0 = ift.from_random('normal', h.domain)

    nsamps = 2
    kl = ift.MetricGaussianKL(mean0,
                              h,
                              nsamps,
                              constants=constants,
                              point_estimates=point_estimates,
                              mirror_samples=mirror_samples,
                              napprox=0)
    klpure = ift.MetricGaussianKL(mean0,
                                  h,
                                  nsamps,
                                  mirror_samples=mirror_samples,
                                  napprox=0,
                                  _samples=kl.samples)

    # Test value
    assert_allclose(kl.value, klpure.value)

    # Test gradient
    for kk in h.domain.keys():
        res0 = klpure.gradient.to_global_data()[kk]
        if kk in constants:
            res0 = 0*res0
        res1 = kl.gradient.to_global_data()[kk]
        assert_allclose(res0, res1)

    # Test number of samples
    expected_nsamps = 2*nsamps if mirror_samples else nsamps
    assert_(len(kl.samples) == expected_nsamps)

    # Test point_estimates (after drawing samples)
    for kk in point_estimates:
        for ss in kl.samples:
            ss = ss.to_global_data()[kk]
            assert_allclose(ss, 0*ss)

    # Test constants (after some minimization)
    cg = ift.GradientNormController(iteration_limit=5)
    minimizer = ift.NewtonCG(cg)
    kl, _ = minimizer(kl)
    diff = (mean0 - kl.position).to_global_data()
    for kk in constants:
        assert_allclose(diff[kk], 0*diff[kk])