Commit 54b34fbe authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more adjustments

parent 5de1cebf
...@@ -79,7 +79,8 @@ if __name__ == '__main__': ...@@ -79,7 +79,8 @@ if __name__ == '__main__':
minimizer = ift.RelaxedNewton(ic_newton) minimizer = ift.RelaxedNewton(ic_newton)
# Minimize the Hamiltonian # Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood, ic_cg) H = ift.Hamiltonian(likelihood)
H = H.makeInvertible(ic_cg)
H, convergence = minimizer(H) H, convergence = minimizer(H)
# Plot results # Plot results
......
...@@ -64,8 +64,7 @@ if __name__ == '__main__': ...@@ -64,8 +64,7 @@ if __name__ == '__main__':
minimizer = ift.RelaxedNewton(ic_newton) minimizer = ift.RelaxedNewton(ic_newton)
# build model Hamiltonian # build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_cg, H = ift.Hamiltonian(likelihood, ic_sampling)
iteration_controller_sampling=ic_sampling)
INITIAL_POSITION = ift.from_random('normal', H.position.domain) INITIAL_POSITION = ift.from_random('normal', H.position.domain)
position = INITIAL_POSITION position = INITIAL_POSITION
......
...@@ -19,25 +19,23 @@ ...@@ -19,25 +19,23 @@
from ..library.gaussian_energy import GaussianEnergy from ..library.gaussian_energy import GaussianEnergy
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..models.variable import Variable from ..models.variable import Variable
from ..operators import InversionEnabler, SamplingEnabler from ..operators.sampling_enabler import SamplingEnabler
from ..utilities import memo from ..utilities import memo
class Hamiltonian(Energy): class Hamiltonian(Energy):
def __init__(self, lh, iteration_controller, def __init__(self, lh, iteration_controller_sampling=None):
iteration_controller_sampling=None):
""" """
lh: Likelihood (energy object) lh: Likelihood (energy object)
prior: prior:
""" """
super(Hamiltonian, self).__init__(lh.position) super(Hamiltonian, self).__init__(lh.position)
self._lh = lh self._lh = lh
self._ic = iteration_controller
self._ic_samp = iteration_controller_sampling self._ic_samp = iteration_controller_sampling
self._prior = GaussianEnergy(Variable(self.position)) self._prior = GaussianEnergy(Variable(self.position))
def at(self, position): def at(self, position):
return self.__class__(self._lh.at(position), self._ic, self._ic_samp) return self.__class__(self._lh.at(position), self._ic_samp)
@property @property
@memo @memo
...@@ -54,11 +52,10 @@ class Hamiltonian(Energy): ...@@ -54,11 +52,10 @@ class Hamiltonian(Energy):
def curvature(self): def curvature(self):
prior_curv = self._prior.curvature prior_curv = self._prior.curvature
if self._ic_samp is None: if self._ic_samp is None:
c = self._lh.curvature + prior_curv return self._lh.curvature + prior_curv
else: else:
c = SamplingEnabler(self._lh.curvature, prior_curv.inverse, return SamplingEnabler(self._lh.curvature, prior_curv.inverse,
self._ic_samp, prior_curv.inverse) self._ic_samp, prior_curv.inverse)
return InversionEnabler(c, self._ic, prior_curv)
def __str__(self): def __str__(self):
res = 'Likelihood:\t{:.2E}\n'.format(self._lh.value) res = 'Likelihood:\t{:.2E}\n'.format(self._lh.value)
......
...@@ -80,12 +80,9 @@ class DescentMinimizer(Minimizer): ...@@ -80,12 +80,9 @@ class DescentMinimizer(Minimizer):
return energy, controller.CONVERGED return energy, controller.CONVERGED
# compute a step length that reduces energy.value sufficiently # compute a step length that reduces energy.value sufficiently
try: new_energy, success = self.line_searcher.perform_line_search(
new_energy, success = self.line_searcher.perform_line_search( energy=energy, pk=self.get_descent_direction(energy),
energy=energy, pk=self.get_descent_direction(energy), f_k_minus_1=f_k_minus_1)
f_k_minus_1=f_k_minus_1)
except ValueError:
return energy, controller.ERROR
if not success: if not success:
self.reset() self.reset()
......
...@@ -274,9 +274,9 @@ class LinearOperator(NiftyMetaBase()): ...@@ -274,9 +274,9 @@ class LinearOperator(NiftyMetaBase()):
def _check_mode(self, mode): def _check_mode(self, mode):
if not self._validMode[mode]: if not self._validMode[mode]:
raise ValueError("invalid operator mode specified") raise NotImplementedError("invalid operator mode specified")
if mode & self.capability == 0: if mode & self.capability == 0:
raise ValueError("requested operator mode is not supported") raise NotImplementedError("requested operator mode is not supported")
def _check_input(self, x, mode): def _check_input(self, x, mode):
self._check_mode(mode) self._check_mode(mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment