diff --git a/demos/getting_started_3b.py b/demos/getting_started_3b.py index 08c99b3251397adfbbde9f8c393593c568812214..2a7ce7d27f3c29d5f61fad49635fd9bac2b7c33d 100644 --- a/demos/getting_started_3b.py +++ b/demos/getting_started_3b.py @@ -97,9 +97,8 @@ if __name__ == '__main__': # build model Hamiltonian H = ift.Hamiltonian(likelihood, ic_sampling) - H = EnergyAdapter(MOCK_POSITION, H) - INITIAL_POSITION = ift.from_random('normal', H.position.domain) + INITIAL_POSITION = ift.from_random('normal', domain) position = INITIAL_POSITION ift.plot(signal(MOCK_POSITION), title='ground truth') @@ -110,11 +109,12 @@ if __name__ == '__main__': # number of samples used to estimate the KL N_samples = 20 for i in range(2): - H = H.at(position) - samples = [H.metric.draw_sample(from_inverse=True) + metric = H(ift.Linearization.make_var(position)).metric + samples = [metric.draw_sample(from_inverse=True) for _ in range(N_samples)] KL = ift.SampledKullbachLeiblerDivergence(H, samples) + KL = EnergyAdapter(position, KL) KL = KL.make_invertible(ic_cg) KL, convergence = minimizer(KL) position = KL.position diff --git a/nifty5/energies/kl.py b/nifty5/energies/kl.py index 403b70bab79d3de61d9b3b36f2bb04453cc231ee..99fb600ed192e8698419a4e60baca529cc544104 100644 --- a/nifty5/energies/kl.py +++ b/nifty5/energies/kl.py @@ -19,41 +19,21 @@ from __future__ import absolute_import, division, print_function from ..compat import * -from ..minimization.energy import Energy -from ..utilities import memo, my_sum +from ..operator import Operator +from ..utilities import my_sum -class SampledKullbachLeiblerDivergence(Energy): +class SampledKullbachLeiblerDivergence(Operator): def __init__(self, h, res_samples): """ # MR FIXME: does h have to be a Hamiltonian? Couldn't it be any energy? h: Hamiltonian N: Number of samples to be used """ - super(SampledKullbachLeiblerDivergence, self).__init__(h.position) + super(SampledKullbachLeiblerDivergence, self).__init__() self._h = h - self._res_samples = res_samples + self._res_samples = tuple(res_samples) - self._energy_list = tuple(h.at(self.position+ss) - for ss in res_samples) - - def at(self, position): - return self.__class__(self._h.at(position), self._res_samples) - - @property - @memo - def value(self): - return (my_sum(map(lambda v: v.value, self._energy_list)) / - len(self._energy_list)) - - @property - @memo - def gradient(self): - return (my_sum(map(lambda v: v.gradient, self._energy_list)) / - len(self._energy_list)) - - @property - @memo - def metric(self): - return (my_sum(map(lambda v: v.metric, self._energy_list)) * - (1./len(self._energy_list))) + def __call__(self, x): + return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * + (1./len(self._res_samples))) diff --git a/nifty5/linearization.py b/nifty5/linearization.py index b639359a4aaefac172ff0a8114f7889ff2fa07f8..dcc9fe5184c562d5f65bad1c2f2d0e4b186d301a 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -82,7 +82,8 @@ class Linearization(object): if isinstance(other, (int, float, complex)): # if other == 0: # return ... - return Linearization(self._val*other, self._jac*other) + met = None if self._metric is None else self._metric*other + return Linearization(self._val*other, self._jac*other, met) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) return Linearization(self._val*other, d2*self._jac)