Commit 007fb8dc authored by Martin Reinecke's avatar Martin Reinecke

adjust KL

parent 34a385d9
......@@ -97,9 +97,8 @@ if __name__ == '__main__':
# build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling)
H = EnergyAdapter(MOCK_POSITION, H)
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
INITIAL_POSITION = ift.from_random('normal', domain)
position = INITIAL_POSITION
ift.plot(signal(MOCK_POSITION), title='ground truth')
......@@ -110,11 +109,12 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples = 20
for i in range(2):
H = H.at(position)
samples = [H.metric.draw_sample(from_inverse=True)
metric = H(ift.Linearization.make_var(position)).metric
samples = [metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = EnergyAdapter(position, KL)
KL = KL.make_invertible(ic_cg)
KL, convergence = minimizer(KL)
position = KL.position
......
......@@ -19,41 +19,21 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..utilities import memo, my_sum
from ..operator import Operator
from ..utilities import my_sum
class SampledKullbachLeiblerDivergence(Energy):
class SampledKullbachLeiblerDivergence(Operator):
def __init__(self, h, res_samples):
"""
# MR FIXME: does h have to be a Hamiltonian? Couldn't it be any energy?
h: Hamiltonian
N: Number of samples to be used
"""
super(SampledKullbachLeiblerDivergence, self).__init__(h.position)
super(SampledKullbachLeiblerDivergence, self).__init__()
self._h = h
self._res_samples = res_samples
self._res_samples = tuple(res_samples)
self._energy_list = tuple(h.at(self.position+ss)
for ss in res_samples)
def at(self, position):
return self.__class__(self._h.at(position), self._res_samples)
@property
@memo
def value(self):
return (my_sum(map(lambda v: v.value, self._energy_list)) /
len(self._energy_list))
@property
@memo
def gradient(self):
return (my_sum(map(lambda v: v.gradient, self._energy_list)) /
len(self._energy_list))
@property
@memo
def metric(self):
return (my_sum(map(lambda v: v.metric, self._energy_list)) *
(1./len(self._energy_list)))
def __call__(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples)))
......@@ -82,7 +82,8 @@ class Linearization(object):
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
met = None if self._metric is None else self._metric*other
return Linearization(self._val*other, self._jac*other, met)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, d2*self._jac)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment