Commit ca2c5d10 authored by Theo Steininger

Modified memo decorator such that injection is possible. Added callback to IterationController.

parent 0155aecd
@@ -54,8 +54,8 @@ if __name__ == "__main__":
     data_domain = R.target[0]
     R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0])
-    ndiag = Field(data_domain,mock_signal.var()/signal_to_noise).weight(1)
-    N = DiagonalOperator(data_domain,ndiag)
+    ndiag = Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
+    N = DiagonalOperator(data_domain, ndiag)
     noise = Field.from_random(domain=data_domain,
                               random_type='normal',
                               std=mock_signal.std()/np.sqrt(signal_to_noise),
@@ -68,16 +68,14 @@ class Energy(with_metaclass(NiftyMeta,
     def __init__(self, position, gradient=None, curvature=None):
         super(Energy, self).__init__()
-        self._cache = {}
         self._position = position.copy()
+        self._cache = {}
         if gradient is not None:
-            key = id(self.gradient)
-            self._cache[key] = gradient
+            self._cache['gradient'] = gradient
         if curvature is not None:
-            key = id(self.curvature)
-            self._cache[key] = curvature
+            self._cache['curvature'] = curvature

     def at(self, position, gradient=None, curvature=None):
         """ Initializes and returns a new Energy object at the new position.
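With string-valued cache keys, the constructor and the memoized properties agree on where an injected value lives: whatever is passed in as gradient or curvature is stored under the very name the corresponding property will look up, so the expensive computation is skipped. Under the old id()-based keys the constructor could not know the right key without reaching into the decorator, and evaluating id(self.gradient) would itself have triggered the computation the injection is meant to avoid. A minimal sketch of the mechanism (not NIFTy code; ToyEnergy and its gradient formula are purely illustrative):

class ToyEnergy(object):
    """Illustrative stand-in for an Energy-like class with an injectable cache."""
    def __init__(self, position, gradient=None):
        self._position = position
        self._cache = {}
        if gradient is not None:
            # same key the cached property reads below
            self._cache['gradient'] = gradient

    @property
    def gradient(self):
        if 'gradient' not in self._cache:
            print("computing gradient the expensive way")
            self._cache['gradient'] = 2 * self._position  # placeholder formula
        return self._cache['gradient']

e = ToyEnergy(position=3.0, gradient=-1.5)
print(e.gradient)   # -> -1.5, served from the injected entry; nothing is recomputed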
@@ -18,9 +18,11 @@
 def memo(f):
-    name = id(f)
+    name = f.__name__

     def wrapped_f(self):
+        if not hasattr(self, "_cache"):
+            self._cache = {}
         try:
             return self._cache[name]
         except KeyError:
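The hunk is cut off after the KeyError branch; under the usual memoization pattern the remainder computes the value once and stores it under the new string key (the two lines marked as "assumed completion" below are not shown in the diff, the rest is). Because wrapped_f now creates _cache on first use, classes relying on @memo no longer have to set up the dictionary themselves, which is also why the LogNormalWienerFilterCurvature hunk further down can drop its own self._cache = {} line.

def memo(f):
    name = f.__name__

    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            # assumed completion: compute once, then serve from the cache
            self._cache[name] = f(self)
            return self._cache[name]
    return wrapped_f

class Example(object):          # note: no _cache created in __init__
    @property
    @memo
    def result(self):
        print("computed once")
        return 42

ex = Example()
print(ex.result, ex.result)     # "computed once" is printed a single time
ex2 = Example()
ex2._cache = {'result': 99}     # injection by name now works from outside, too
print(ex2.result)               # -> 99, the injected value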
@@ -28,7 +28,6 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
     def __init__(self, R, N, S, d, position, inverter=None, fft4exp=None,
                  offset=None, **kwargs):
-        self._cache = {}
         self.R = R
         self.N = N
         self.S = S
@@ -110,4 +110,5 @@ class ConjugateGradient(Minimizer):
                 return energy, controller.ERROR
             r -= q * alpha
-            energy = energy.at_with_grad(energy.position+d*alpha, -r)
+            energy = energy.at(position=energy.position + d*alpha,
+                               gradient=-r)
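Switching from at_with_grad to the keyword form of at feeds the gradient the conjugate gradient iteration already knows into the new Energy object, so no extra gradient or curvature application is spent on it. The identity being exploited is that, for a quadratic energy, the CG residual is exactly minus the gradient at the updated position. A small self-contained check of that identity (plain NumPy; A, b, d, q are local to this sketch, not NIFTy objects):

import numpy as np

# For a quadratic energy E(x) = 0.5 * x.A.x - b.x the gradient is A.x - b,
# so the CG residual r = b - A.x equals -gradient(x).
A = np.array([[4., 1.], [1., 3.]])
b = np.array([1., 2.])
x = np.zeros(2)
r = b - A.dot(x)                 # residual = -gradient at x
d = r.copy()                     # first search direction
q = A.dot(d)
alpha = r.dot(r) / d.dot(q)
x_new = x + d * alpha
r_new = r - q * alpha            # updated residual, no extra operator application
grad_new = A.dot(x_new) - b      # gradient at the new position, for comparison
assert np.allclose(r_new, -grad_new)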
@@ -22,8 +22,8 @@ from .iteration_controller import IterationController
 class GradientNormController(IterationController):
     def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
-                 convergence_level=1, iteration_limit=None):
-        super(GradientNormController, self).__init__()
+                 convergence_level=1, iteration_limit=None, callback=None):
+        super(GradientNormController, self).__init__(callback=callback)
         self._tol_abs_gradnorm = tol_abs_gradnorm
         self._tol_rel_gradnorm = tol_rel_gradnorm
         self._tol_rel_gradnorm_now = None
@@ -38,7 +38,9 @@ class GradientNormController(IterationController):
                 * energy.gradient_norm

     def check(self, energy):
-        self._iteration_count += 1
+        super_check = super(GradientNormController, self).check(energy)
+        if super_check != self.CONTINUE:
+            return super_check

         # check if position is at a flat point
         if energy.gradient_norm == 0:
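Since the bookkeeping (iteration counting and the callback) now lives in the base class check, subclasses are expected to delegate to it first and only run their own convergence test when the base class returns CONTINUE. A hypothetical controller written against this pattern (EnergyValueController and its tolerance are invented for illustration; it assumes IterationController is importable as in the hunks below and that the Energy objects expose a value property):

class EnergyValueController(IterationController):
    """Stops when the relative change of the energy value becomes small."""
    def __init__(self, tol_rel_deltaE=1e-7, callback=None):
        super(EnergyValueController, self).__init__(callback=callback)
        self._tol_rel_deltaE = tol_rel_deltaE
        self._last_value = None

    def check(self, energy):
        # base class counts the iteration and runs the callback
        super_check = super(EnergyValueController, self).check(energy)
        if super_check != self.CONTINUE:
            return super_check
        value = energy.value
        if self._last_value is not None and \
                abs(value - self._last_value) <= self._tol_rel_deltaE * abs(value):
            return self.CONVERGED
        self._last_value = value
        return self.CONTINUE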
@@ -18,7 +18,7 @@
 from builtins import range

 import abc

-from ..nifty_meta import NiftyMeta
+from ...nifty_meta import NiftyMeta
 from keepers import Loggable
 from future.utils import with_metaclass
@@ -45,9 +45,10 @@ class IterationController(
     CONVERGED, CONTINUE, STOPPED, ERROR = list(range(4))

-    def __init__(self):
+    def __init__(self, callback=None):
         self._iteration_count = 0
         self._convergence_count = 0
+        self.callback = callback

     @property
     def iteration_count(self):
@@ -69,7 +70,6 @@ class IterationController(
         raise NotImplementedError

-    @abc.abstractmethod
     def check(self, energy):
         """
         Parameters
@@ -82,4 +82,12 @@ class IterationController(
         status : integer status, can be CONVERGED, CONTINUE or ERROR

         """
-        raise NotImplementedError
+        self._iteration_count += 1
+        if self.callback is not None:
+            try:
+                self.callback(energy, self._iteration_count)
+            except StopIteration:
+                self.logger.info("Minimization was stopped by callback "
+                                 "function.")
+                return self.STOPPED
+        return self.CONTINUE
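With the callback wired through GradientNormController into this base implementation, a caller can observe every iteration and abort a run without touching the controller's tolerances; raising StopIteration inside the callback makes check return STOPPED. A usage sketch (the tolerance value and the iteration budget are arbitrary; how the controller is handed to a minimizer is outside this hunk and omitted here):

def my_callback(energy, iteration_count):
    # called once per check(); energy is the current Energy object
    print("iteration %d: gradient norm %g"
          % (iteration_count, energy.gradient_norm))
    if iteration_count >= 100:
        raise StopIteration      # reported as IterationController.STOPPED

controller = GradientNormController(tol_abs_gradnorm=1e-5,
                                    callback=my_callback)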