Commit 387298b6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add energy logging

parent 4d8c1460
......@@ -11,10 +11,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import functools
import numpy as np
from ..logger import logger
......@@ -41,6 +43,9 @@ class IterationController(metaclass=NiftyMeta):
CONVERGED, CONTINUE, ERROR = list(range(3))
def __init__(self):
self._history = None
def start(self, energy):
"""Starts the iteration.
......@@ -69,6 +74,28 @@ class IterationController(metaclass=NiftyMeta):
"""
raise NotImplementedError
def pop_history(self):
"""FIXME"""
if self._history is None:
raise RuntimeError('No history was taken')
res = self._history
self._history = []
return res
def activate_and_reset_logging(self):
"""FIXME"""
self._history = []
def append_history(func):
"""FIXME"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if args[0]._history is not None:
args[0]._history.append(args[1].value)
return func(*args, **kwargs)
return wrapper
class GradientNormController(IterationController):
"""An iteration controller checking (mainly) the L2 gradient norm.
......@@ -94,12 +121,14 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
......@@ -108,6 +137,7 @@ class GradientNormController(IterationController):
* energy.gradient_norm
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -163,16 +193,19 @@ class GradInfNormController(IterationController):
def __init__(self, tol, convergence_level=1, iteration_limit=None,
name=None):
super(GradInfNormController, self).__init__()
self._tol = tol
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -224,17 +257,20 @@ class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None):
super(DeltaEnergyController, self).__init__()
self._tol_rel_deltaE = tol_rel_deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......@@ -290,17 +326,20 @@ class AbsDeltaEnergyController(IterationController):
def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
name=None):
super(AbsDeltaEnergyController, self).__init__()
self._deltaE = deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
@append_history
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
@append_history
def check(self, energy):
self._itcount += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment