Commit 1507fdce authored by Martin Reinecke

some cleanups

parent 7ecdbcec
@@ -74,7 +74,8 @@ if __name__ == '__main__':
# set up minimization and inversion schemes
ic_cg = ift.GradientNormController(iteration_limit=10)
ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=100)
ic_newton = ift.DeltaEnergyController(
name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
minimizer = ift.RelaxedNewton(ic_newton)
# minimizer = ift.VL_BFGS(ic_newton)
# minimizer = ift.NewtonCG(xtol=1e-10, maxiter=100, disp=True)
......
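For context, the controller configured above is consumed by the minimizer, which calls its start() and check() methods during the run. A minimal sketch of that interplay, assuming NIFTy's usual convention that a minimizer is invoked on an Energy object and returns the final energy together with the controller's status (`initial_energy` is a placeholder, not part of this commit):

# Sketch only: how the controller drives the minimization loop.
ic_newton = ift.DeltaEnergyController(
    name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
minimizer = ift.RelaxedNewton(ic_newton)
energy, status = minimizer(initial_energy)  # iterates until the controller stops
if status != ic_newton.CONVERGED:
    raise RuntimeError("Newton minimization did not converge")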
@@ -53,8 +53,8 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.line_search_strong_wolfe import LineSearchStrongWolfe
from .minimization.iteration_controller import IterationController
from .minimization.gradient_norm_controller import GradientNormController
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..utilities import NiftyMetaBase
class IterationController(NiftyMetaBase()):
"""The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a
minimization iteration. At the beginning of the minimization, its start()
method is called with the energy object at the initial position.
Afterwards, its check() method is called during every iteration step with
the energy object describing the current position.
Based on that information, the iteration controller has to decide whether
the iteration should progress further (in this case it returns CONTINUE),
whether sufficient convergence has been reached (in this case it returns
CONVERGED), or whether some error has been detected (then it returns ERROR).
The concrete convergence criteria can be chosen by inheriting from this
class; the implementer has full flexibility to use whichever criteria are
appropriate for a particular problem, as long as they can be computed from
the information passed to the controller during the iteration process.
"""
CONVERGED, CONTINUE, ERROR = list(range(3))
def start(self, energy):
"""Starts the iteration.
Parameters
----------
energy : Energy object
Energy object at the start of the iteration
Returns
-------
status : int
    Can be CONVERGED, CONTINUE, or ERROR.
"""
raise NotImplementedError
def check(self, energy):
"""Checks the state of the iteration. Called after every step.
Parameters
----------
energy : Energy object
Energy object at the current position
Returns
-------
status : int
    Can be CONVERGED, CONTINUE, or ERROR.
"""
raise NotImplementedError
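The start()/check() protocol above is all a concrete controller has to implement. As a hedged illustration (this class is hypothetical and not part of NIFTy), a controller that simply stops after a fixed number of steps could look like:

class StepCountController(IterationController):
    """Toy controller: reports convergence after a fixed number of steps."""
    def __init__(self, nsteps):
        self._nsteps = nsteps

    def start(self, energy):
        # called once, with the energy at the initial position
        self._itcount = 0
        return self.CONTINUE

    def check(self, energy):
        # called after every step, with the energy at the current position
        self._itcount += 1
        return self.CONVERGED if self._itcount >= self._nsteps else self.CONTINUE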
@@ -20,7 +20,56 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..logger import logger
from .iteration_controller import IterationController
from ..utilities import NiftyMetaBase
class IterationController(NiftyMetaBase()):
"""The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a
minimization iteration. At the beginning of the minimization, its start()
method is called with the energy object at the initial position.
Afterwards, its check() method is called during every iteration step with
the energy object describing the current position.
Based on that information, the iteration controller has to decide whether
the iteration should progress further (in this case it returns CONTINUE),
whether sufficient convergence has been reached (in this case it returns
CONVERGED), or whether some error has been detected (then it returns ERROR).
The concrete convergence criteria can be chosen by inheriting from this
class; the implementer has full flexibility to use whichever criteria are
appropriate for a particular problem, as long as they can be computed from
the information passed to the controller during the iteration process.
"""
CONVERGED, CONTINUE, ERROR = list(range(3))
def start(self, energy):
"""Starts the iteration.
Parameters
----------
energy : Energy object
Energy object at the start of the iteration
Returns
-------
status : int
    Can be CONVERGED, CONTINUE, or ERROR.
"""
raise NotImplementedError
def check(self, energy):
"""Checks the state of the iteration. Called after every step.
Parameters
----------
energy : Energy object
Energy object at the current position
Returns
-------
status : int
    Can be CONVERGED, CONTINUE, or ERROR.
"""
raise NotImplementedError
class GradientNormController(IterationController):
@@ -54,20 +103,6 @@ class GradientNormController(IterationController):
self._name = name
def start(self, energy):
""" Start a new iteration.
The iteration and convergence counters are set to 0.
Parameters
----------
energy : Energy
The energy functional to be minimized.
Returns
-------
status : int
    Can be CONVERGED or CONTINUE.
"""
self._itcount = -1
self._ccount = 0
if self._tol_rel_gradnorm is not None:
@@ -76,27 +111,6 @@ class GradientNormController(IterationController):
return self.check(energy)
def check(self, energy):
""" Check for convergence.
- Increase the iteration counter by 1.
- If any of the convergence criteria are fulfilled, increase the
convergence counter by 1; else decrease it by 1 (but not below 0).
- If the convergence counter reaches the convergence level, return
  CONVERGED.
- If the iteration counter reaches the iteration limit, issue a warning
  and return CONVERGED.
- Otherwise return CONTINUE.
Parameters
----------
energy : Energy
The current solution estimate
Returns
-------
status : int
    Can be CONVERGED or CONTINUE.
"""
self._itcount += 1
inclvl = False
@@ -113,20 +127,65 @@ class GradientNormController(IterationController):
# report
if self._name is not None:
msg = self._name+":"
msg += " Iteration #" + str(self._itcount)
msg += " energy={:.6E}".format(energy.value)
msg += " gradnorm={:.2E}".format(energy.gradient_norm)
msg += " clvl=" + str(self._ccount)
logger.info(msg)
# self.logger.info(msg)
logger.info(
"{}: Iteration #{} energy={:.6E} gradnorm={:.2E} clvl={}"
.format(self._name, self._itcount, energy.value,
energy.gradient_norm, self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
return self.CONTINUE
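The convergence counter above acts as a hysteresis: the criterion has to hold for a net `convergence_level` iterations before CONVERGED is reported, so a single lucky step does not terminate the run. A standalone sketch of just that counter logic, with made-up relative changes:

# Illustrative values only: changes below the tolerance push the
# counter up, changes above it pull the counter down (not below 0).
tol, level, ccount = 1e-8, 2, 0
for rel in [1e-3, 1e-9, 1e-4, 1e-9, 1e-9]:
    ccount = ccount + 1 if rel < tol else max(0, ccount - 1)
print(ccount >= level)  # True: the last two steps confirmed convergence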
class DeltaEnergyController(IterationController):
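"""Controller that declares convergence once the relative change of the
energy value between consecutive steps stays below a given tolerance.

Parameters
----------
tol_rel_deltaE : float
    Relative energy change |E_old-E_new| / max(|E_old|, |E_new|) below
    which an iteration counts towards convergence.
convergence_level : int, optional
    Number of net confirming iterations required. Default: 1.
iteration_limit : int, optional
    Iteration count at which convergence is assumed (with a warning).
name : str, optional
    If given, progress is reported via the logger under this name.
"""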
def __init__(self, tol_rel_deltaE, convergence_level=1,
iteration_limit=None, name=None):
self._tol_rel_deltaE = tol_rel_deltaE
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
def start(self, energy):
self._itcount = -1
self._ccount = 0
self._Eold = 0.
return self.check(energy)
def check(self, energy):
self._itcount += 1
inclvl = False
Eval = energy.value
rel = abs(self._Eold-Eval)/max(abs(self._Eold), abs(Eval))
if self._itcount > 0:
if rel < self._tol_rel_deltaE:
inclvl = True
self._Eold = Eval
if inclvl:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._name is not None:
logger.info(
"{}: Iteration #{} energy={:.6E} reldiff={:.6E} clvl={}"
.format(self._name, self._itcount, Eval, rel, self._ccount))
# Are we done?
if self._iteration_limit is not None:
if self._itcount >= self._iteration_limit:
msg = "" if self._name is None else self._name+": "
msg += "Iteration limit reached. Assuming convergence"
logger.warning(msg)
logger.warning(
"{} Iteration limit reached. Assuming convergence"
.format("" if self._name is None else self._name+": "))
return self.CONVERGED
if self._ccount >= self._convergence_level:
return self.CONVERGED
......
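The stopping criterion compares consecutive energy values on a relative scale, rel = |E_old - E_new| / max(|E_old|, |E_new|). A worked example with made-up numbers:

# Hypothetical values: two Newton steps with nearly identical energies.
Eold, Eval = -1234.56780, -1234.56781
rel = abs(Eold - Eval) / max(abs(Eold), abs(Eval))
# rel is about 8.1e-9; with tol_rel_deltaE=1e-8 this iteration would
# increase the convergence counter.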
@@ -136,16 +136,16 @@ class ChainOperator(LinearOperator):
x = op.apply(x, mode)
return x
def draw_sample(self, from_inverse=False, dtype=np.float64):
from ..sugar import from_random
if len(self._ops) == 1:
return self._ops[0].draw_sample(from_inverse, dtype)
samp = from_random(random_type="normal", domain=self._domain,
dtype=dtype)
for op in self._ops:
samp = op.process_sample(samp, from_inverse)
return samp
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
# return self._ops[0].draw_sample(from_inverse, dtype)
#
# samp = from_random(random_type="normal", domain=self._domain,
# dtype=dtype)
# for op in self._ops:
# samp = op.process_sample(samp, from_inverse)
# return samp
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
......
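The now-disabled method implemented the usual pattern for sampling from an operator chain: draw unit-variance white noise on the domain, then push it through each factor's process_sample(). A hedged standalone sketch of the same idea in plain NumPy (function name and arguments are hypothetical, not library API):

import numpy as np

def sample_through_chain(factors, shape, rng=None):
    # `factors` are callables playing the role of process_sample()
    rng = np.random.default_rng() if rng is None else rng
    samp = rng.standard_normal(shape)  # white noise on the domain
    for process in factors:            # apply each chain factor in turn
        samp = process(samp)
    return samp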
@@ -173,5 +173,4 @@ class DiagonalOperator(EndomorphicOperator):
return self.process_sample(res, from_inverse)
def __repr__(self):
subs = utilities.indent(self._domain.__repr__())
return "DiagonalOperator:\n Spaces={}\n".format(self._spaces) + subs
return "DiagonalOperator"
@@ -60,6 +60,9 @@ class Operator(NiftyMetaBase()):
return _OpChain.make((self, x))
return self.apply(x)
def __repr__(self):
return self.__class__.__name__
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
def func(f):
......
@@ -88,8 +88,8 @@ class ScalingOperator(EndomorphicOperator):
raise ValueError("operator not positive definite")
return 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
def process_sample(self, samp, from_inverse):
return samp*self._get_fct(from_inverse)
# def process_sample(self, samp, from_inverse):
# return samp*self._get_fct(from_inverse)
def draw_sample(self, from_inverse=False, dtype=np.float64):
from ..sugar import from_random
@@ -97,5 +97,4 @@ class ScalingOperator(EndomorphicOperator):
std=self._get_fct(from_inverse), dtype=dtype)
def __repr__(self):
subs = utilities.indent(self._domain.__repr__())
return "ScalingOperator:\n Factor={}\n".format(self._factor) + subs
return "ScalingOperator ({})".format(self._factor)