Commit 007fb8dc authored by Martin Reinecke's avatar Martin Reinecke

adjust KL

parent 34a385d9
......@@ -97,9 +97,8 @@ if __name__ == '__main__':
# build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling)
H = EnergyAdapter(MOCK_POSITION, H)
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
INITIAL_POSITION = ift.from_random('normal', domain)
ift.plot(signal(MOCK_POSITION), title='ground truth')
......@@ -110,11 +109,12 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples = 20
for i in range(2):
H =
samples = [H.metric.draw_sample(from_inverse=True)
metric = H(ift.Linearization.make_var(position)).metric
samples = [metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = EnergyAdapter(position, KL)
KL = KL.make_invertible(ic_cg)
KL, convergence = minimizer(KL)
position = KL.position
......@@ -19,41 +19,21 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from import Energy
from ..utilities import memo, my_sum
from ..operator import Operator
from ..utilities import my_sum
class SampledKullbachLeiblerDivergence(Energy):
class SampledKullbachLeiblerDivergence(Operator):
def __init__(self, h, res_samples):
# MR FIXME: does h have to be a Hamiltonian? Couldn't it be any energy?
h: Hamiltonian
N: Number of samples to be used
super(SampledKullbachLeiblerDivergence, self).__init__(h.position)
super(SampledKullbachLeiblerDivergence, self).__init__()
self._h = h
self._res_samples = res_samples
self._res_samples = tuple(res_samples)
self._energy_list = tuple(
for ss in res_samples)
def at(self, position):
return self.__class__(, self._res_samples)
def value(self):
return (my_sum(map(lambda v: v.value, self._energy_list)) /
def gradient(self):
return (my_sum(map(lambda v: v.gradient, self._energy_list)) /
def metric(self):
return (my_sum(map(lambda v: v.metric, self._energy_list)) *
def __call__(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
......@@ -82,7 +82,8 @@ class Linearization(object):
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
met = None if self._metric is None else self._metric*other
return Linearization(self._val*other, self._jac*other, met)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, d2*self._jac)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment