Skip to content
Snippets Groups Projects
Commit 1cbaf2d0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add new iteration controller

parent 8e75c29a
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,8 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController)
IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
......
......@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..logger import logger
from ..utilities import NiftyMetaBase
import numpy as np
class IterationController(NiftyMetaBase()):
......@@ -145,6 +146,48 @@ class GradientNormController(IterationController):
return self.CONTINUE
class GradInfNormController(IterationController):
def __init__(self, tol=None, convergence_level=1, iteration_limit=None,
name=None):
self._tol = tol
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
def start(self, energy):
self._itcount = -1
self._ccount = 0
return self.check(energy)
def check(self, energy):
self._itcount += 1
crit = energy.gradient.norm(np.inf) / abs(energy.value)
if self._tol is not None and crit <= self._tol:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} crit={:.2E} clvl={}"
.format(self._name, self._itcount, energy.value,
crit, self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
return self.CONTINUE
class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment