Commit 9584a133 authored by Philipp Arras's avatar Philipp Arras Committed by Lukas Platz

add AbsDeltaEnergyController

parent 5ded9f44
Pipeline #60211 passed with stages
in 8 minutes and 25 seconds
......@@ -59,7 +59,7 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController)
GradInfNormController, AbsDeltaEnergyController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
......
......@@ -18,6 +18,7 @@
from ..logger import logger
from ..utilities import NiftyMeta
import numpy as np
from time import time
class IterationController(metaclass=NiftyMeta):
......@@ -266,3 +267,75 @@ class DeltaEnergyController(IterationController):
return self.CONVERGED
return self.CONTINUE
class AbsDeltaEnergyController(IterationController):
"""An iteration controller checking (mainly) the energy change from one
iteration to the next.
Parameters
----------
deltaE : float
If the difference between the last and current energies is below this
value, the convergence counter will be increased in this iteration.
convergence_level : int, default=1
The number which the convergence counter must reach before the
iteration is considered to be converged
iteration_limit : int, optional
The maximum number of iterations that will be carried out.
name : str, optional
if supplied, this string and some diagnostic information will be
printed after every iteration
"""
def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
name=None, file_name=None):
self._deltaE = deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
self._file_name = file_name
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
def check(self, energy):
self._itcount += 1
inclvl = False
Eval = energy.value
diff = abs(self._Eold-Eval)
if self._itcount > 0:
if diff < self._deltaE:
inclvl = True
self._Eold = Eval
if inclvl:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} diff={:.6E} crit={:.6E}"
.format(self._name, self._itcount, Eval, diff, self._deltaE))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
# Write energy to file
if self._file_name is not None:
with open(self._file_name, 'a+') as f:
f.write('{} {} {}\n'.format(time(), energy.value, diff))
return self.CONTINUE
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment