From a29167035480c570ab30223711618fb7f1817d7b Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Mon, 20 Aug 2018 16:10:29 +0200 Subject: [PATCH] Introduce KL_Energy (which might be parallelized in the future) --- demos/getting_started_3.py | 19 ++++------ nifty5/__init__.py | 1 + nifty5/minimization/energy_adapter.py | 12 ------- nifty5/minimization/kl_energy.py | 50 +++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 24 deletions(-) create mode 100644 nifty5/minimization/kl_energy.py diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index 43387801..54e7a4c9 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -91,26 +91,21 @@ if __name__ == '__main__': # number of samples used to estimate the KL N_samples = 20 for i in range(2): - metric = H(ift.Linearization.make_var(position)).metric - samples = [metric.draw_sample(from_inverse=True) - for _ in range(N_samples)] - - KL = ift.SampledKullbachLeiblerDivergence(H, samples) - KL = ift.EnergyAdapter(position, KL) + KL = ift.KL_Energy(position, H, N_samples) KL, convergence = minimizer(KL) position = KL.position - ift.plot(signal(position), title="reconstruction") - ift.plot([A(position), A(MOCK_POSITION)], title="power") + ift.plot(signal(KL.position), title="reconstruction") + ift.plot([A(KL.position), A(MOCK_POSITION)], title="power") ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") sc = ift.StatCalculator() - for sample in samples: - sc.add(signal(sample+position)) + for sample in KL.samples: + sc.add(signal(sample+KL.position)) ift.plot(sc.mean, title="mean") ift.plot(ift.sqrt(sc.var), title="std deviation") - powers = [A(s+position) for s in samples] - ift.plot([A(position), A(MOCK_POSITION)]+powers, title="power") + powers = [A(s+KL.position) for s in KL.samples] + ift.plot([A(KL.position), A(MOCK_POSITION)]+powers, title="power") ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="results.png") diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 3576309b..ccbbbac8 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -66,6 +66,7 @@ from .minimization.energy import Energy from .minimization.quadratic_energy import QuadraticEnergy from .minimization.line_energy import LineEnergy from .minimization.energy_adapter import EnergyAdapter +from .minimization.kl_energy import KL_Energy from .sugar import * from .plotting.plot import plot, plot_finish diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py index 34c26ea0..7d5c0e29 100644 --- a/nifty5/minimization/energy_adapter.py +++ b/nifty5/minimization/energy_adapter.py @@ -26,18 +26,6 @@ class EnergyAdapter(Energy): def at(self, position): return EnergyAdapter(position, self._op, self._constants) - def _fill_all(self): - if len(self._constants) == 0: - tmp = self._op(Linearization.make_var(self._position)) - else: - ops = [ScalingOperator(0. if key in self._constants else 1., dom) - for key, dom in self._position.domain.items()] - bdop = BlockDiagonalOperator(self._position.domain, tuple(ops)) - tmp = self._op(Linearization(self._position, bdop)) - self._val = tmp.val.local_data[()] - self._grad = tmp.gradient - self._metric = tmp._metric - @property def value(self): return self._val diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py new file mode 100644 index 00000000..d21dbdaf --- /dev/null +++ b/nifty5/minimization/kl_energy.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import, division, print_function + +from ..compat import * +from .energy import Energy +from ..linearization import Linearization +from ..operators.scaling_operator import ScalingOperator +from ..operators.block_diagonal_operator import BlockDiagonalOperator +from .. import utilities + + +class KL_Energy(Energy): + def __init__(self, position, h, nsamp, constants=[], _samples=None): + super(KL_Energy, self).__init__(position) + self._h = h + self._constants = constants + if _samples is None: + met = h(Linearization.make_var(position)).metric + _samples = tuple(met.draw_sample(from_inverse=True) + for _ in range(nsamp)) + self._samples = _samples + if len(constants) == 0: + tmp = Linearization.make_var(position) + else: + ops = [ScalingOperator(0. if key in constants else 1., dom) + for key, dom in position.domain.items()] + bdop = BlockDiagonalOperator(position.domain, tuple(ops)) + tmp = Linearization(position, bdop) + mymap = map(lambda v: self._h(tmp+v), self._samples) + tmp = utilities.my_sum(mymap) * (1./len(self._samples)) + self._val = tmp.val.local_data[()] + self._grad = tmp.gradient + self._metric = tmp.metric + + def at(self, position): + return KL_Energy(position, self._h, 0, self._constants, self._samples) + + @property + def value(self): + return self._val + + @property + def gradient(self): + return self._grad + + def apply_metric(self, x): + return self._metric(x) + + @property + def samples(self): + return self._samples -- GitLab