Commit 0d71c455 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

no more custom convergence measures

parent c7e32d13
Pipeline #18994 passed with stage
in 4 minutes and 13 seconds
...@@ -90,7 +90,7 @@ if __name__ == "__main__": ...@@ -90,7 +90,7 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data)) j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True, tol_custom=1e-3, convergence_level=3) ctrl = ift.DefaultIterationController(verbose=True, iteration_limit=100)
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter) wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
......
...@@ -95,10 +95,8 @@ class LineEnergy(object): ...@@ -95,10 +95,8 @@ class LineEnergy(object):
""" """
return self.__class__(line_position, return LineEnergy(line_position, self.energy, self.line_direction,
self.energy, offset=self.line_position)
self.line_direction,
offset=self.line_position)
@property @property
def value(self): def value(self):
......
...@@ -8,23 +8,21 @@ class QuadraticEnergy(Energy): ...@@ -8,23 +8,21 @@ class QuadraticEnergy(Energy):
position-independent. position-independent.
""" """
def __init__(self, position, A, b, _grad=None, _bnorm=None): def __init__(self, position, A, b, _grad=None):
super(QuadraticEnergy, self).__init__(position=position) super(QuadraticEnergy, self).__init__(position=position)
self._A = A self._A = A
self._b = b self._b = b
self._bnorm = _bnorm
if _grad is not None: if _grad is not None:
self._Ax = _grad + self._b self._Ax = _grad + self._b
else: else:
self._Ax = self._A(self.position) self._Ax = self._A(self.position)
def at(self, position): def at(self, position):
return self.__class__(position=position, A=self._A, b=self._b, return QuadraticEnergy(position=position, A=self._A, b=self._b)
_bnorm=self.norm_b)
def at_with_grad(self, position, grad): def at_with_grad(self, position, grad):
return self.__class__(position=position, A=self._A, b=self._b, return QuadraticEnergy(position=position, A=self._A, b=self._b,
_grad=grad, _bnorm=self.norm_b) _grad=grad)
@property @property
@memo @memo
...@@ -39,9 +37,3 @@ class QuadraticEnergy(Energy): ...@@ -39,9 +37,3 @@ class QuadraticEnergy(Energy):
@property @property
def curvature(self): def curvature(self):
return self._A return self._A
@property
def norm_b(self):
if self._bnorm is None:
self._bnorm = self._b.norm()
return self._bnorm
...@@ -73,7 +73,6 @@ class ConjugateGradient(Minimizer): ...@@ -73,7 +73,6 @@ class ConjugateGradient(Minimizer):
if status != controller.CONTINUE: if status != controller.CONTINUE:
return energy, status return energy, status
norm_b = energy.norm_b
r = energy.gradient r = energy.gradient
if preconditioner is not None: if preconditioner is not None:
d = preconditioner(r) d = preconditioner(r)
...@@ -111,9 +110,7 @@ class ConjugateGradient(Minimizer): ...@@ -111,9 +110,7 @@ class ConjugateGradient(Minimizer):
if gamma == 0: if gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
status = self._controller.check(energy, status = self._controller.check(energy)
custom_measure=np.sqrt(gamma) /
norm_b)
if status != controller.CONTINUE: if status != controller.CONTINUE:
return energy, status return energy, status
......
...@@ -22,26 +22,25 @@ from .iteration_controller import IterationController ...@@ -22,26 +22,25 @@ from .iteration_controller import IterationController
class DefaultIterationController(IterationController): class DefaultIterationController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
tol_custom=None, convergence_level=1, iteration_limit=None, convergence_level=1, iteration_limit=None,
name=None, verbose=None): name=None, verbose=None):
super(DefaultIterationController, self).__init__() super(DefaultIterationController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._tol_custom = tol_custom
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
self._verbose = verbose self._verbose = verbose
def start(self, energy, custom_measure=None): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
if self._tol_rel_gradnorm is not None: if self._tol_rel_gradnorm is not None:
self._tol_rel_gradnorm_now = self._tol_rel_gradnorm \ self._tol_rel_gradnorm_now = self._tol_rel_gradnorm \
* energy.gradient_norm * energy.gradient_norm
return self.check(energy, custom_measure) return self.check(energy)
def check(self, energy, custom_measure=None): def check(self, energy):
self._itcount += 1 self._itcount += 1
inclvl = False inclvl = False
...@@ -51,9 +50,6 @@ class DefaultIterationController(IterationController): ...@@ -51,9 +50,6 @@ class DefaultIterationController(IterationController):
if self._tol_rel_gradnorm is not None: if self._tol_rel_gradnorm is not None:
if energy.gradient_norm <= self._tol_rel_gradnorm_now: if energy.gradient_norm <= self._tol_rel_gradnorm_now:
inclvl = True inclvl = True
if self._tol_custom is not None and custom_measure is not None:
if custom_measure <= self._tol_custom:
inclvl = True
if inclvl: if inclvl:
self._ccount += 1 self._ccount += 1
else: else:
...@@ -67,8 +63,6 @@ class DefaultIterationController(IterationController): ...@@ -67,8 +63,6 @@ class DefaultIterationController(IterationController):
msg += " Iteration #" + str(self._itcount) msg += " Iteration #" + str(self._itcount)
msg += " energy=" + str(energy.value) msg += " energy=" + str(energy.value)
msg += " gradnorm=" + str(energy.gradient_norm) msg += " gradnorm=" + str(energy.gradient_norm)
if custom_measure is not None:
msg += " custom=" + str(custom_measure)
msg += " clvl=" + str(self._ccount) msg += " clvl=" + str(self._ccount)
print(msg) print(msg)
# self.logger.info(msg) # self.logger.info(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment