Commit 387298b6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add energy logging

parent 4d8c1460
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import functools
import numpy as np import numpy as np
from ..logger import logger from ..logger import logger
...@@ -41,6 +43,9 @@ class IterationController(metaclass=NiftyMeta): ...@@ -41,6 +43,9 @@ class IterationController(metaclass=NiftyMeta):
CONVERGED, CONTINUE, ERROR = list(range(3)) CONVERGED, CONTINUE, ERROR = list(range(3))
def __init__(self):
self._history = None
def start(self, energy): def start(self, energy):
"""Starts the iteration. """Starts the iteration.
...@@ -69,6 +74,28 @@ class IterationController(metaclass=NiftyMeta): ...@@ -69,6 +74,28 @@ class IterationController(metaclass=NiftyMeta):
""" """
raise NotImplementedError raise NotImplementedError
def pop_history(self):
"""FIXME"""
if self._history is None:
raise RuntimeError('No history was taken')
res = self._history
self._history = []
return res
def activate_and_reset_logging(self):
"""FIXME"""
self._history = []
def append_history(func):
"""FIXME"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if args[0]._history is not None:
args[0]._history.append(args[1].value)
return func(*args, **kwargs)
return wrapper
class GradientNormController(IterationController): class GradientNormController(IterationController):
"""An iteration controller checking (mainly) the L2 gradient norm. """An iteration controller checking (mainly) the L2 gradient norm.
...@@ -94,12 +121,14 @@ class GradientNormController(IterationController): ...@@ -94,12 +121,14 @@ class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None): convergence_level=1, iteration_limit=None, name=None):
super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
@append_history
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
...@@ -108,6 +137,7 @@ class GradientNormController(IterationController): ...@@ -108,6 +137,7 @@ class GradientNormController(IterationController):
* energy.gradient_norm * energy.gradient_norm
return self.check(energy) return self.check(energy)
@append_history
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
...@@ -163,16 +193,19 @@ class GradInfNormController(IterationController): ...@@ -163,16 +193,19 @@ class GradInfNormController(IterationController):
def __init__(self, tol, convergence_level=1, iteration_limit=None, def __init__(self, tol, convergence_level=1, iteration_limit=None,
name=None): name=None):
super(GradInfNormController, self).__init__()
self._tol = tol self._tol = tol
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
@append_history
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
return self.check(energy) return self.check(energy)
@append_history
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
...@@ -224,17 +257,20 @@ class DeltaEnergyController(IterationController): ...@@ -224,17 +257,20 @@ class DeltaEnergyController(IterationController):
def __init__(self, tol_rel_deltaE, convergence_level=1, def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None): iteration_limit=None, name=None):
super(DeltaEnergyController, self).__init__()
self._tol_rel_deltaE = tol_rel_deltaE self._tol_rel_deltaE = tol_rel_deltaE
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
@append_history
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
self._Eold = 0. self._Eold = 0.
return self.check(energy) return self.check(energy)
@append_history
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
...@@ -290,17 +326,20 @@ class AbsDeltaEnergyController(IterationController): ...@@ -290,17 +326,20 @@ class AbsDeltaEnergyController(IterationController):
def __init__(self, deltaE, convergence_level=1, iteration_limit=None, def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
name=None): name=None):
super(AbsDeltaEnergyController, self).__init__()
self._deltaE = deltaE self._deltaE = deltaE
self._convergence_level = convergence_level self._convergence_level = convergence_level
self._iteration_limit = iteration_limit self._iteration_limit = iteration_limit
self._name = name self._name = name
@append_history
def start(self, energy): def start(self, energy):
self._itcount = -1 self._itcount = -1
self._ccount = 0 self._ccount = 0
self._Eold = 0. self._Eold = 0.
return self.check(energy) return self.check(energy)
@append_history
def check(self, energy): def check(self, energy):
self._itcount += 1 self._itcount += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment