From 0d71c455d6ff7271583822af7ad2dfb6787a6f11 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 29 Sep 2017 18:30:42 +0200
Subject: [PATCH] no more custom convergence measures

---
 demos/paper_demos/cartesian_wiener_filter.py     |  2 +-
 nifty/energies/line_energy.py                    |  6 ++----
 nifty/energies/quadratic_energy.py               | 16 ++++------------
 nifty/minimization/conjugate_gradient.py         |  5 +----
 .../minimization/default_iteration_controller.py | 14 ++++----------
 5 files changed, 12 insertions(+), 31 deletions(-)

diff --git a/demos/paper_demos/cartesian_wiener_filter.py b/demos/paper_demos/cartesian_wiener_filter.py
index eeab6ab5c..69cc161a7 100644
--- a/demos/paper_demos/cartesian_wiener_filter.py
+++ b/demos/paper_demos/cartesian_wiener_filter.py
@@ -90,7 +90,7 @@ if __name__ == "__main__":
 
     # Wiener filter
     j = R_harmonic.adjoint_times(N.inverse_times(data))
-    ctrl = ift.DefaultIterationController(verbose=True, tol_custom=1e-3, convergence_level=3)
+    ctrl = ift.DefaultIterationController(verbose=True, iteration_limit=100)
     inverter = ift.ConjugateGradient(controller=ctrl)
     wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
 
diff --git a/nifty/energies/line_energy.py b/nifty/energies/line_energy.py
index f0e67deef..7215072bc 100644
--- a/nifty/energies/line_energy.py
+++ b/nifty/energies/line_energy.py
@@ -95,10 +95,8 @@ class LineEnergy(object):
 
         """
 
-        return self.__class__(line_position,
-                              self.energy,
-                              self.line_direction,
-                              offset=self.line_position)
+        return LineEnergy(line_position, self.energy, self.line_direction,
+                          offset=self.line_position)
 
     @property
     def value(self):
diff --git a/nifty/energies/quadratic_energy.py b/nifty/energies/quadratic_energy.py
index 2f97da569..7194c66be 100644
--- a/nifty/energies/quadratic_energy.py
+++ b/nifty/energies/quadratic_energy.py
@@ -8,23 +8,21 @@ class QuadraticEnergy(Energy):
     position-independent.
     """
 
-    def __init__(self, position, A, b, _grad=None, _bnorm=None):
+    def __init__(self, position, A, b, _grad=None):
         super(QuadraticEnergy, self).__init__(position=position)
         self._A = A
         self._b = b
-        self._bnorm = _bnorm
         if _grad is not None:
             self._Ax = _grad + self._b
         else:
             self._Ax = self._A(self.position)
 
     def at(self, position):
-        return self.__class__(position=position, A=self._A, b=self._b,
-                              _bnorm=self.norm_b)
+        return QuadraticEnergy(position=position, A=self._A, b=self._b)
 
     def at_with_grad(self, position, grad):
-        return self.__class__(position=position, A=self._A, b=self._b,
-                              _grad=grad, _bnorm=self.norm_b)
+        return QuadraticEnergy(position=position, A=self._A, b=self._b,
+                               _grad=grad)
 
     @property
     @memo
@@ -39,9 +37,3 @@ class QuadraticEnergy(Energy):
     @property
     def curvature(self):
         return self._A
-
-    @property
-    def norm_b(self):
-        if self._bnorm is None:
-            self._bnorm = self._b.norm()
-        return self._bnorm
diff --git a/nifty/minimization/conjugate_gradient.py b/nifty/minimization/conjugate_gradient.py
index 98ab46ecb..ed458d94f 100644
--- a/nifty/minimization/conjugate_gradient.py
+++ b/nifty/minimization/conjugate_gradient.py
@@ -73,7 +73,6 @@ class ConjugateGradient(Minimizer):
         if status != controller.CONTINUE:
             return energy, status
 
-        norm_b = energy.norm_b
         r = energy.gradient
         if preconditioner is not None:
             d = preconditioner(r)
@@ -111,9 +110,7 @@ class ConjugateGradient(Minimizer):
             if gamma == 0:
                 return energy, controller.CONVERGED
 
-            status = self._controller.check(energy,
-                                            custom_measure=np.sqrt(gamma) /
-                                            norm_b)
+            status = self._controller.check(energy)
             if status != controller.CONTINUE:
                 return energy, status
 
diff --git a/nifty/minimization/default_iteration_controller.py b/nifty/minimization/default_iteration_controller.py
index cd596e3e4..fa5353dc6 100644
--- a/nifty/minimization/default_iteration_controller.py
+++ b/nifty/minimization/default_iteration_controller.py
@@ -22,26 +22,25 @@ from .iteration_controller import IterationController
 
 class DefaultIterationController(IterationController):
     def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
-                 tol_custom=None, convergence_level=1, iteration_limit=None,
+                 convergence_level=1, iteration_limit=None,
                  name=None, verbose=None):
         super(DefaultIterationController, self).__init__()
         self._tol_abs_gradnorm = tol_abs_gradnorm
         self._tol_rel_gradnorm = tol_rel_gradnorm
-        self._tol_custom = tol_custom
         self._convergence_level = convergence_level
         self._iteration_limit = iteration_limit
         self._name = name
         self._verbose = verbose
 
-    def start(self, energy, custom_measure=None):
+    def start(self, energy):
         self._itcount = -1
         self._ccount = 0
         if self._tol_rel_gradnorm is not None:
             self._tol_rel_gradnorm_now = self._tol_rel_gradnorm \
                                        * energy.gradient_norm
-        return self.check(energy, custom_measure)
+        return self.check(energy)
 
-    def check(self, energy, custom_measure=None):
+    def check(self, energy):
         self._itcount += 1
 
         inclvl = False
@@ -51,9 +50,6 @@ class DefaultIterationController(IterationController):
         if self._tol_rel_gradnorm is not None:
             if energy.gradient_norm <= self._tol_rel_gradnorm_now:
                 inclvl = True
-        if self._tol_custom is not None and custom_measure is not None:
-            if custom_measure <= self._tol_custom:
-                inclvl = True
         if inclvl:
             self._ccount += 1
         else:
@@ -67,8 +63,6 @@ class DefaultIterationController(IterationController):
             msg += " Iteration #" + str(self._itcount)
             msg += " energy=" + str(energy.value)
             msg += " gradnorm=" + str(energy.gradient_norm)
-            if custom_measure is not None:
-                msg += " custom=" + str(custom_measure)
             msg += " clvl=" + str(self._ccount)
             print(msg)
             # self.logger.info(msg)
-- 
GitLab