Commit 12c5065b authored by Philipp Arras's avatar Philipp Arras

Merge branch 'AbsDeltaEnergyController' into 'NIFTy_5'

Add AbsDeltaEnergyController to NIFTy

See merge request !341
parents 5ded9f44 21053a6b
Pipeline #61020 passed with stages
in 23 minutes and 4 seconds
......@@ -59,7 +59,7 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController)
GradInfNormController, AbsDeltaEnergyController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
......
......@@ -15,9 +15,10 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..logger import logger
from ..utilities import NiftyMeta
import numpy as np
class IterationController(metaclass=NiftyMeta):
......@@ -266,3 +267,70 @@ class DeltaEnergyController(IterationController):
return self.CONVERGED
return self.CONTINUE
class AbsDeltaEnergyController(IterationController):
"""An iteration controller checking (mainly) the energy change from one
iteration to the next.
Parameters
----------
deltaE : float
If the difference between the last and current energies is below this
value, the convergence counter will be increased in this iteration.
convergence_level : int, default=1
The number which the convergence counter must reach before the
iteration is considered to be converged
iteration_limit : int, optional
The maximum number of iterations that will be carried out.
name : str, optional
if supplied, this string and some diagnostic information will be
printed after every iteration
"""
def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
name=None):
self._deltaE = deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
def check(self, energy):
self._itcount += 1
inclvl = False
Eval = energy.value
diff = abs(self._Eold-Eval)
if self._itcount > 0:
if diff < self._deltaE:
inclvl = True
self._Eold = Eval
if inclvl:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} diff={:.6E} crit={:.1E} clvl={}"
.format(self._name, self._itcount, Eval, diff, self._deltaE,
self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
return self.CONTINUE
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment