Commit fc889da1 authored by Philipp Arras's avatar Philipp Arras

Remove intermediate functionality from iteration controllers

parent be972356
Pipeline #61918 passed with stages
in 8 minutes and 17 seconds
...@@ -90,41 +90,33 @@ class GradientNormController(IterationController): ...@@ -90,41 +90,33 @@ class GradientNormController(IterationController):
name : str, optional name : str, optional
if supplied, this string and some diagnostic information will be if supplied, this string and some diagnostic information will be
printed after every iteration printed after every iteration
p : float
Order of norm, default is the 2-Norm (p=2)
""" """
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None, p=2): convergence_level=1, iteration_limit=None, name=None):
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
self._p = p
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
if self._tol_rel_gradnorm is not None: if self._tol_rel_gradnorm is not None:
self._tol_rel_gradnorm_now = self._tol_rel_gradnorm * self._norm(energy) self._tol_rel_gradnorm_now = self._tol_rel_gradnorm \
* energy.gradient_norm
return self.check(energy) return self.check(energy)
def _norm(self, energy):
# FIXME Only p=2 norm is cached in energy class
if self._p == 2:
return energy.gradient_norm
return energy.gradient.norm(self._p)
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
inclvl = False inclvl = False
if self._tol_abs_gradnorm is not None: if self._tol_abs_gradnorm is not None:
if self._norm(energy) <= self._tol_abs_gradnorm: if energy.gradient_norm <= self._tol_abs_gradnorm:
inclvl = True inclvl = True
if self._tol_rel_gradnorm is not None: if self._tol_rel_gradnorm is not None:
if self._norm(energy) <= self._tol_rel_gradnorm_now: if energy.gradient_norm <= self._tol_rel_gradnorm_now:
inclvl = True inclvl = True
if inclvl: if inclvl:
self._ccount += 1 self._ccount += 1
...@@ -136,7 +128,7 @@ class GradientNormController(IterationController): ...@@ -136,7 +128,7 @@ class GradientNormController(IterationController):
logger.info( logger.info(
"{}: Iteration #{} energy={:.6E} gradnorm={:.2E} clvl={}" "{}: Iteration #{} energy={:.6E} gradnorm={:.2E} clvl={}"
.format(self._name, self._itcount, energy.value, .format(self._name, self._itcount, energy.value,
self._norm(energy), self._ccount)) energy.gradient_norm, self._ccount))
# Are we done? # Are we done?
if self._iteration_limit is not None: if self._iteration_limit is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment