Commit a2916703 authored by Martin Reinecke's avatar Martin Reinecke

Introduce KL_Energy (which might be parallelized in the future)

parent 747e2082
...@@ -91,26 +91,21 @@ if __name__ == '__main__': ...@@ -91,26 +91,21 @@ if __name__ == '__main__':
# number of samples used to estimate the KL # number of samples used to estimate the KL
N_samples = 20 N_samples = 20
for i in range(2): for i in range(2):
metric = H(ift.Linearization.make_var(position)).metric KL = ift.KL_Energy(position, H, N_samples)
samples = [metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = ift.EnergyAdapter(position, KL)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
position = KL.position position = KL.position
ift.plot(signal(position), title="reconstruction") ift.plot(signal(KL.position), title="reconstruction")
ift.plot([A(position), A(MOCK_POSITION)], title="power") ift.plot([A(KL.position), A(MOCK_POSITION)], title="power")
ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png")
sc = ift.StatCalculator() sc = ift.StatCalculator()
for sample in samples: for sample in KL.samples:
sc.add(signal(sample+position)) sc.add(signal(sample+KL.position))
ift.plot(sc.mean, title="mean") ift.plot(sc.mean, title="mean")
ift.plot(ift.sqrt(sc.var), title="std deviation") ift.plot(ift.sqrt(sc.var), title="std deviation")
powers = [A(s+position) for s in samples] powers = [A(s+KL.position) for s in KL.samples]
ift.plot([A(position), A(MOCK_POSITION)]+powers, title="power") ift.plot([A(KL.position), A(MOCK_POSITION)]+powers, title="power")
ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", ift.plot_finish(nx=3, xsize=16, ysize=5, title="results",
name="results.png") name="results.png")
...@@ -66,6 +66,7 @@ from .minimization.energy import Energy ...@@ -66,6 +66,7 @@ from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.line_energy import LineEnergy from .minimization.line_energy import LineEnergy
from .minimization.energy_adapter import EnergyAdapter from .minimization.energy_adapter import EnergyAdapter
from .minimization.kl_energy import KL_Energy
from .sugar import * from .sugar import *
from .plotting.plot import plot, plot_finish from .plotting.plot import plot, plot_finish
......
...@@ -26,18 +26,6 @@ class EnergyAdapter(Energy): ...@@ -26,18 +26,6 @@ class EnergyAdapter(Energy):
def at(self, position): def at(self, position):
return EnergyAdapter(position, self._op, self._constants) return EnergyAdapter(position, self._op, self._constants)
def _fill_all(self):
if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position))
else:
ops = [ScalingOperator(0. if key in self._constants else 1., dom)
for key, dom in self._position.domain.items()]
bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
tmp = self._op(Linearization(self._position, bdop))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp._metric
@property @property
def value(self): def value(self):
return self._val return self._val
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from .energy import Energy
from ..linearization import Linearization
from ..operators.scaling_operator import ScalingOperator
from ..operators.block_diagonal_operator import BlockDiagonalOperator
from .. import utilities
class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], _samples=None):
super(KL_Energy, self).__init__(position)
self._h = h
self._constants = constants
if _samples is None:
met = h(Linearization.make_var(position)).metric
_samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp))
self._samples = _samples
if len(constants) == 0:
tmp = Linearization.make_var(position)
else:
ops = [ScalingOperator(0. if key in constants else 1., dom)
for key, dom in position.domain.items()]
bdop = BlockDiagonalOperator(position.domain, tuple(ops))
tmp = Linearization(position, bdop)
mymap = map(lambda v: self._h(tmp+v), self._samples)
tmp = utilities.my_sum(mymap) * (1./len(self._samples))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp.metric
def at(self, position):
return KL_Energy(position, self._h, 0, self._constants, self._samples)
@property
def value(self):
return self._val
@property
def gradient(self):
return self._grad
def apply_metric(self, x):
return self._metric(x)
@property
def samples(self):
return self._samples
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment