Skip to content
Snippets Groups Projects
Commit 1cbaf2d0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add new iteration controller

parent 8e75c29a
No related branches found
No related tags found
No related merge requests found
...@@ -55,7 +55,8 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \ ...@@ -55,7 +55,8 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import ( from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController) IterationController, GradientNormController, DeltaEnergyController,
GradInfNormController)
from .minimization.minimizer import Minimizer from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG from .minimization.nonlinear_cg import NonlinearCG
......
...@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function ...@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import * from ..compat import *
from ..logger import logger from ..logger import logger
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
import numpy as np
class IterationController(NiftyMetaBase()): class IterationController(NiftyMetaBase()):
...@@ -145,6 +146,48 @@ class GradientNormController(IterationController): ...@@ -145,6 +146,48 @@ class GradientNormController(IterationController):
return self.CONTINUE return self.CONTINUE
class GradInfNormController(IterationController):
def __init__(self, tol=None, convergence_level=1, iteration_limit=None,
name=None):
self._tol = tol
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
def start(self, energy):
self._itcount = -1
self._ccount = 0
return self.check(energy)
def check(self, energy):
self._itcount += 1
crit = energy.gradient.norm(np.inf) / abs(energy.value)
if self._tol is not None and crit <= self._tol:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} crit={:.2E} clvl={}"
.format(self._name, self._itcount, energy.value,
crit, self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
return self.CONTINUE
class DeltaEnergyController(IterationController): class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1, def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None): iteration_limit=None, name=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment