Commit 77f85939 authored by Philipp Arras's avatar Philipp Arras

Migrate hamiltonian and energy from global newton

parent 67e96903
Pipeline #31977 failed with stages
in 12 seconds
......@@ -21,5 +21,7 @@ from .logger import logger
from .multi import *
from .energies import *
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty5 import *"; that would swamp the global namespace.
from .hamiltonian import Hamiltonian
from .kl import SampledKullbachLeiblerDivergence
from nifty5 import Energy, InversionEnabler, SamplingEnabler, Variable, memo
from nifty5.library import UnitLogGauss
class Hamiltonian(Energy):
def __init__(self, lh, iteration_controller,
iteration_controller_sampling=None):
"""
lh: Likelihood (energy object)
prior:
"""
super(Hamiltonian, self).__init__(lh.position)
self._lh = lh
self._ic = iteration_controller
if iteration_controller_sampling is None:
self._ic_samp = iteration_controller
else:
self._ic_samp = iteration_controller_sampling
self._prior = UnitLogGauss(Variable(self.position))
self._precond = self._prior.curvature
def at(self, position):
return self.__class__(self._lh.at(position), self._ic, self._ic_samp)
@property
@memo
def value(self):
return self._lh.value + self._prior.value
@property
@memo
def gradient(self):
return self._lh.gradient + self._prior.gradient
@property
@memo
def curvature(self):
prior_curv = self._prior.curvature
c = SamplingEnabler(self._lh.curvature, prior_curv.inverse,
self._ic_samp, prior_curv.inverse)
return InversionEnabler(c, self._ic, self._precond)
def __str__(self):
res = 'Likelihood:\t{:.2E}\n'.format(self._lh.value)
res += 'Prior:\t\t{:.2E}'.format(self._prior.value)
return res
from nifty5 import Energy, InversionEnabler, ScalingOperator, memo
class SampledKullbachLeiblerDivergence(Energy):
def __init__(self, h, res_samples, iteration_controller):
"""
h: Hamiltonian
N: Number of samples to be used
"""
super(SampledKullbachLeiblerDivergence, self).__init__(h.position)
self._h = h
self._res_samples = res_samples
self._iteration_controller = iteration_controller
self._energy_list = []
for ss in res_samples:
e = h.at(self.position+ss)
self._energy_list.append(e)
def at(self, position):
return self.__class__(self._h.at(position), self._res_samples,
self._iteration_controller)
@property
@memo
def value(self):
v = self._energy_list[0].value
for energy in self._energy_list[1:]:
v += energy.value
return v / len(self._energy_list)
@property
@memo
def gradient(self):
g = self._energy_list[0].gradient
for energy in self._energy_list[1:]:
g += energy.gradient
return g / len(self._energy_list)
@property
@memo
def curvature(self):
# MR FIXME: This looks a bit strange...
approx = self._energy_list[-1]._prior.curvature
curvature_list = [e.curvature for e in self._energy_list]
op = curvature_list[0]
for curv in curvature_list[1:]:
op = op + curv
op = op * ScalingOperator(1./len(curvature_list), op.domain)
return InversionEnabler(op, self._iteration_controller, approx)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment