Commit 1cbaf2d0 authored by Martin Reinecke's avatar Martin Reinecke

add new iteration controller

parent 8e75c29a
......@@ -55,7 +55,8 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController)
IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
......
......@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..logger import logger
from ..utilities import NiftyMetaBase
import numpy as np
class IterationController(NiftyMetaBase()):
......@@ -145,6 +146,48 @@ class GradientNormController(IterationController):
return self.CONTINUE
class GradInfNormController(IterationController):
def __init__(self, tol=None, convergence_level=1, iteration_limit=None,
name=None):
self._tol = tol
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
def start(self, energy):
self._itcount = -1
self._ccount = 0
return self.check(energy)
def check(self, energy):
self._itcount += 1
crit = energy.gradient.norm(np.inf) / abs(energy.value)
if self._tol is not None and crit <= self._tol:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} crit={:.2E} clvl={}"
.format(self._name, self._itcount, energy.value,
crit, self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
return self.CONTINUE
class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment