From b2e21a64f8469c2f11b9c00e750b9dd512d9df2e Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Wed, 24 Jul 2019 15:47:17 +0200
Subject: [PATCH] Add Reimar's criterion to NewtonCG

---
 nifty5/__init__.py                           |  2 +-
 nifty5/minimization/descent_minimizers.py    | 20 +++---
 nifty5/minimization/iteration_controllers.py | 67 ++++++++++++++++++++
 nifty5/probing.py                            |  1 +
 4 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 0dfba4cc0..a9326ba4a 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -59,7 +59,7 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
 from .minimization.line_search import LineSearch
 from .minimization.iteration_controllers import (
     IterationController, GradientNormController, DeltaEnergyController,
-    GradInfNormController)
+    GradInfNormController, AbsDeltaEnergyController)
 from .minimization.minimizer import Minimizer
 from .minimization.conjugate_gradient import ConjugateGradient
 from .minimization.nonlinear_cg import NonlinearCG
diff --git a/nifty5/minimization/descent_minimizers.py b/nifty5/minimization/descent_minimizers.py
index b25c41963..d0520ea5b 100644
--- a/nifty5/minimization/descent_minimizers.py
+++ b/nifty5/minimization/descent_minimizers.py
@@ -19,7 +19,6 @@ import numpy as np
 
 from ..logger import logger
 from .conjugate_gradient import ConjugateGradient
-from .iteration_controllers import GradientNormController
 from .line_search import LineSearch
 from .minimizer import Minimizer
 from .quadratic_energy import QuadraticEnergy
@@ -46,7 +45,7 @@ class DescentMinimizer(Minimizer):
         self._controller = controller
         self.line_searcher = line_searcher
 
-    def __call__(self, energy, preconditioner=None):
+    def __call__(self, energy):
         """Performs the minimization of the provided Energy functional.
 
         Parameters
@@ -82,7 +81,7 @@ class DescentMinimizer(Minimizer):
 
             # compute a step length that reduces energy.value sufficiently
             new_energy, success = self.line_searcher.perform_line_search(
-                energy=energy, pk=self.get_descent_direction(energy),
+                energy=energy, pk=self.get_descent_direction(energy, f_k_minus_1),
                 f_k_minus_1=f_k_minus_1)
             if not success:
                 self.reset()
@@ -163,12 +162,15 @@ class NewtonCG(DescentMinimizer):
         super(NewtonCG, self).__init__(controller=controller,
                                        line_searcher=line_searcher)
 
-    def get_descent_direction(self, energy):
-        g = energy.gradient
-        maggrad = abs(g).sum()
-        termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
-        ic = GradientNormController(tol_abs_gradnorm=termcond, p=1)
-        e = QuadraticEnergy(0*energy.position, energy.metric, g)
+    def get_descent_direction(self, energy, f_k_minus_1):
+        from .iteration_controllers import AbsDeltaEnergyController, GradientNormController
+        if f_k_minus_1 is None:
+            ic = GradientNormController(iteration_limit=1)
+        else:
+            alpha = 0.1
+            ediff = alpha*(f_k_minus_1 - energy.value)
+            ic = AbsDeltaEnergyController(ediff, iteration_limit=200, name='    Internal', convergence_level=1)
+        e = QuadraticEnergy(0*energy.position, energy.metric, energy.gradient)
         e, conv = ConjugateGradient(ic, nreset=np.inf)(e)
         if conv == ic.ERROR:
             raise RuntimeError
diff --git a/nifty5/minimization/iteration_controllers.py b/nifty5/minimization/iteration_controllers.py
index 88142a869..efaaa14e8 100644
--- a/nifty5/minimization/iteration_controllers.py
+++ b/nifty5/minimization/iteration_controllers.py
@@ -276,3 +276,70 @@ class DeltaEnergyController(IterationController):
             return self.CONVERGED
 
         return self.CONTINUE
+
+
+class AbsDeltaEnergyController(IterationController):
+    """An iteration controller checking (mainly) the energy change from one
+    iteration to the next.
+
+    Parameters
+    ----------
+    tol_rel_deltaE : float
+        If the difference between the last and current energies divided by
+        the current energy is below this value, the convergence counter will
+        be increased in this iteration.
+    convergence_level : int, default=1
+        The number which the convergence counter must reach before the
+        iteration is considered to be converged
+    iteration_limit : int, optional
+        The maximum number of iterations that will be carried out.
+    name : str, optional
+        if supplied, this string and some diagnostic information will be
+        printed after every iteration
+    """
+
+    def __init__(self, deltaE, convergence_level=1,
+                 iteration_limit=None, name=None):
+        self._deltaE = deltaE
+        self._convergence_level = convergence_level
+        self._iteration_limit = iteration_limit
+        self._name = name
+
+    def start(self, energy):
+        self._itcount = -1
+        self._ccount = 0
+        self._Eold = 0.
+        return self.check(energy)
+
+    def check(self, energy):
+        self._itcount += 1
+
+        inclvl = False
+        Eval = energy.value
+        diff = abs(self._Eold-Eval)
+        if self._itcount > 0:
+            if diff < self._deltaE:
+                inclvl = True
+        self._Eold = Eval
+        if inclvl:
+            self._ccount += 1
+        else:
+            self._ccount = max(0, self._ccount-1)
+
+        # report
+        if self._name is not None:
+            logger.info(
+                "{}: Iteration #{} energy={:.6E} diff={:.6E} crit={:.6E}"
+                .format(self._name, self._itcount, Eval, diff, self._deltaE))
+
+        # Are we done?
+        if self._iteration_limit is not None:
+            if self._itcount >= self._iteration_limit:
+                logger.warning(
+                    "{} Iteration limit reached. Assuming convergence"
+                    .format("" if self._name is None else self._name+": "))
+                return self.CONVERGED
+        if self._ccount >= self._convergence_level:
+            return self.CONVERGED
+
+        return self.CONTINUE
diff --git a/nifty5/probing.py b/nifty5/probing.py
index 1eca3105e..2c1ad8a3f 100644
--- a/nifty5/probing.py
+++ b/nifty5/probing.py
@@ -138,6 +138,7 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
 
 
 def approximation2endo(op, nsamples):
+    print('Calculate preconditioner')
     sc = StatCalculator()
     for _ in range(nsamples):
         sc.add(op.draw_sample())
-- 
GitLab