Commit 54b34fbe authored by Martin Reinecke's avatar Martin Reinecke

more adjustments

parent 5de1cebf
......@@ -79,7 +79,8 @@ if __name__ == '__main__':
minimizer = ift.RelaxedNewton(ic_newton)
# Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood, ic_cg)
H = ift.Hamiltonian(likelihood)
H = H.makeInvertible(ic_cg)
H, convergence = minimizer(H)
# Plot results
......
......@@ -64,8 +64,7 @@ if __name__ == '__main__':
minimizer = ift.RelaxedNewton(ic_newton)
# build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_cg,
iteration_controller_sampling=ic_sampling)
H = ift.Hamiltonian(likelihood, ic_sampling)
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
position = INITIAL_POSITION
......
......@@ -19,25 +19,23 @@
from ..library.gaussian_energy import GaussianEnergy
from ..minimization.energy import Energy
from ..models.variable import Variable
from ..operators import InversionEnabler, SamplingEnabler
from ..operators.sampling_enabler import SamplingEnabler
from ..utilities import memo
class Hamiltonian(Energy):
def __init__(self, lh, iteration_controller,
iteration_controller_sampling=None):
def __init__(self, lh, iteration_controller_sampling=None):
"""
lh: Likelihood (energy object)
prior:
"""
super(Hamiltonian, self).__init__(lh.position)
self._lh = lh
self._ic = iteration_controller
self._ic_samp = iteration_controller_sampling
self._prior = GaussianEnergy(Variable(self.position))
def at(self, position):
return self.__class__(self._lh.at(position), self._ic, self._ic_samp)
return self.__class__(self._lh.at(position), self._ic_samp)
@property
@memo
......@@ -54,11 +52,10 @@ class Hamiltonian(Energy):
def curvature(self):
prior_curv = self._prior.curvature
if self._ic_samp is None:
c = self._lh.curvature + prior_curv
return self._lh.curvature + prior_curv
else:
c = SamplingEnabler(self._lh.curvature, prior_curv.inverse,
self._ic_samp, prior_curv.inverse)
return InversionEnabler(c, self._ic, prior_curv)
return SamplingEnabler(self._lh.curvature, prior_curv.inverse,
self._ic_samp, prior_curv.inverse)
def __str__(self):
res = 'Likelihood:\t{:.2E}\n'.format(self._lh.value)
......
......@@ -80,12 +80,9 @@ class DescentMinimizer(Minimizer):
return energy, controller.CONVERGED
# compute a step length that reduces energy.value sufficiently
try:
new_energy, success = self.line_searcher.perform_line_search(
energy=energy, pk=self.get_descent_direction(energy),
f_k_minus_1=f_k_minus_1)
except ValueError:
return energy, controller.ERROR
new_energy, success = self.line_searcher.perform_line_search(
energy=energy, pk=self.get_descent_direction(energy),
f_k_minus_1=f_k_minus_1)
if not success:
self.reset()
......
......@@ -274,9 +274,9 @@ class LinearOperator(NiftyMetaBase()):
def _check_mode(self, mode):
if not self._validMode[mode]:
raise ValueError("invalid operator mode specified")
raise NotImplementedError("invalid operator mode specified")
if mode & self.capability == 0:
raise ValueError("requested operator mode is not supported")
raise NotImplementedError("requested operator mode is not supported")
def _check_input(self, x, mode):
self._check_mode(mode)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment