Commit dd3b24aa authored by Martin Reinecke

more restructuring

parent 59ca2979
Pipeline #16889 failed in 6 minutes and 19 seconds
__init__.py
@@ -17,6 +17,9 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from line_searching import *
from iteration_controller import IterationController
from default_iteration_controller import DefaultIterationController
from minimizer import Minimizer
from conjugate_gradient import ConjugateGradient
from descent_minimizer import DescentMinimizer
from steepest_descent import SteepestDescent
......
conjugate_gradient.py
@@ -19,10 +19,10 @@
from __future__ import division
import numpy as np
from keepers import Loggable
from .minimizer import Minimizer


class ConjugateGradient(Loggable, object):
class ConjugateGradient(Minimizer):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
@@ -30,43 +30,22 @@ class ConjugateGradient(Loggable, object):
    Parameters
    ----------
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-4)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None).
    reset_count : integer *optional*
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions (default: None).
    preconditioner : Operator *optional*
        Operator that transforms the variables of the system to improve the
        conditioning (default: None).
    callback : callable *optional*
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When called, the current
        energy and iteration_number are passed. (default: None)

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer
        Maximum number of iterations performed.
    reset_count : integer
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions.
    preconditioner : Operator
        Operator that transforms the variables of the system to improve the
        conditioning (default: None).
    callback : callable
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When called, the current
        energy and iteration_number are passed. (default: None)
    controller : IterationController
        The iteration controller that steers the minimization and decides
        when to stop.

    References
    ----------
@@ -75,23 +54,13 @@ class ConjugateGradient(Loggable, object):
"""
def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
iteration_limit=None, reset_count=None,
preconditioner=None, callback=None):
self.convergence_tolerance = np.float(convergence_tolerance)
self.convergence_level = np.float(convergence_level)
if iteration_limit is not None:
iteration_limit = int(iteration_limit)
self.iteration_limit = iteration_limit
def __init__(self, controller, reset_count=None, preconditioner=None):
if reset_count is not None:
reset_count = int(reset_count)
self.reset_count = reset_count
self.preconditioner = preconditioner
self.callback = callback
self._controller = controller
def __call__(self, E):
""" Runs the conjugate gradient minimization.
@@ -111,6 +80,11 @@ class ConjugateGradient(Loggable, object):
"""
controller = self._controller
status = controller.start(E)
if status != controller.CONTINUE:
return E, status
r = -E.gradient
if self.preconditioner is not None:
d = self.preconditioner(r)
@@ -118,26 +92,20 @@ class ConjugateGradient(Loggable, object):
            d = r.copy()

        previous_gamma = (r.vdot(d)).real
        if previous_gamma == 0:
            self.logger.info("The starting guess is already a perfect "
                             "solution of the inverse problem.")
            return E, self.convergence_level+1
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")
            return E, controller.CONVERGED

        while True:
            if self.callback is not None:
                self.callback(E, iteration_number)

            q = E.curvature(d)
            alpha = previous_gamma/(d.vdot(q).real)

            if not np.isfinite(alpha):
                self.logger.error("Alpha became infinite! Stopping.")
                return E, 0
                return E, controller.ERROR

            E = E.at(E.position+d*alpha)

            status = self._controller.check(E)
            if status != controller.CONTINUE:
                return E, status

            reset = False
            if alpha < 0:
@@ -155,42 +123,14 @@ class ConjugateGradient(Loggable, object):
                s = self.preconditioner(r)
            else:
                s = r.copy()

            gamma = r.vdot(s).real
            if gamma < 0:
                self.logger.warn("Positive definiteness of preconditioner "
                                 "violated!")

            beta = max(0, gamma/previous_gamma)

            delta = r.norm()

            self.logger.debug("Iteration : %08u alpha = %3.1E "
                              "beta = %3.1E delta = %3.1E" %
                              (iteration_number, alpha, beta, delta))

            if gamma == 0:
                convergence = self.convergence_level+1
                self.logger.info("Reached infinite convergence.")
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break
                return E, controller.CONVERGED

            d = s + d * beta
            d = s + d * max(0, gamma/previous_gamma)

            iteration_number += 1
            previous_gamma = gamma

        return E, convergence
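For orientation, here is a self-contained NumPy sketch of the same preconditioned-CG recursion the method above implements (the alpha, gamma and max(0, gamma/previous_gamma) updates), applied to a small symmetric positive-definite system. Variable names mirror the code; the textbook residual update r -= alpha*q stands in for re-evaluating the gradient. This is illustrative only and not part of the commit.

import numpy as np

A = np.array([[4., 1.], [1., 3.]])   # stands in for E.curvature
b = np.array([1., 2.])
x = np.zeros(2)

r = b - np.dot(A, x)                 # residual, i.e. -gradient at x
d = r.copy()                         # no preconditioner: d = r
previous_gamma = np.dot(r, r)
for _ in range(20):
    q = np.dot(A, d)                 # q = E.curvature(d)
    alpha = previous_gamma / np.dot(d, q)
    x += alpha * d
    r -= alpha * q
    gamma = np.dot(r, r)
    if gamma < 1e-20:                # converged
        break
    d = r + d * max(0, gamma / previous_gamma)
    previous_gamma = gamma

assert np.allclose(np.dot(A, x), b)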
default_iteration_controller.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .iteration_controller import IterationController


class DefaultIterationController(IterationController):
    def __init__(self, tol_gradnorm=None, convergence_level=1,
                 iteration_limit=None):
        super(DefaultIterationController, self).__init__()
        self._tol_gradnorm = tol_gradnorm
        self._convergence_level = convergence_level
        self._iteration_limit = iteration_limit

    def start(self, energy):
        self._itcount = -1
        self._ccount = 0
        return self.check(energy)

    def check(self, energy):
        self._itcount += 1
        self.logger.debug("iteration %d: gradnorm %g, level %d"
                          % (self._itcount, energy.gradient_norm,
                             self._ccount))
        # reaching the iteration limit is reported as convergence
        if self._iteration_limit is not None:
            if self._itcount >= self._iteration_limit:
                return self.CONVERGED
        if self._tol_gradnorm is not None:
            if energy.gradient_norm <= self._tol_gradnorm:
                self._ccount += 1
                if self._ccount >= self._convergence_level:
                    return self.CONVERGED
            else:
                self._ccount = max(0, self._ccount-1)
        return self.CONTINUE
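A minimal usage sketch of the controller protocol defined above, assuming the class is importable; MockEnergy is a hypothetical stand-in that exposes only the gradient_norm attribute the controller inspects.

class MockEnergy(object):
    def __init__(self, gradient_norm):
        self.gradient_norm = gradient_norm

controller = DefaultIterationController(tol_gradnorm=1e-5,
                                        convergence_level=2,
                                        iteration_limit=100)
energy = MockEnergy(1.0)
status = controller.start(energy)      # resets counters, does first check
while status == controller.CONTINUE:
    energy = MockEnergy(energy.gradient_norm * 0.1)   # fake descent step
    status = controller.check(energy)

assert status == controller.CONVERGED

Note the convergence_level logic: the gradient-norm tolerance has to be undershot on convergence_level consecutive checks before CONVERGED is returned.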
descent_minimizer.py
@@ -17,16 +17,13 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
from nifty.nifty_meta import NiftyMeta
import numpy as np
from keepers import Loggable
from .minimizer import Minimizer
from .line_searching import LineSearchStrongWolfe
class DescentMinimizer(Loggable, object):
class DescentMinimizer(Minimizer):
""" A base class used by gradient methods to find a local minimum.
Descent minimization methods are used to find a local minimum of a scalar
@@ -43,23 +40,9 @@ class DescentMinimizer(Loggable, object):
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When called, the current
        energy and iteration_number are passed. (default: None)
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-4)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None).

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer
        Maximum number of iterations performed.
    line_searcher : LineSearch
        Function which infers the optimal step size for functional
        minimization given a descent direction.
@@ -77,21 +60,11 @@ class DescentMinimizer(Loggable, object):
"""
__metaclass__ = NiftyMeta
def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None,
convergence_tolerance=1E-4, convergence_level=3,
iteration_limit=None):
self.convergence_tolerance = np.float(convergence_tolerance)
self.convergence_level = np.int(convergence_level)
if iteration_limit is not None:
iteration_limit = int(iteration_limit)
self.iteration_limit = iteration_limit
def __init__(self, controller, line_searcher=LineSearchStrongWolfe()):
super(DescentMinimizer, self).__init__()
self.line_searcher = line_searcher
self.callback = callback
self._controller = controller
def __call__(self, energy):
""" Performs the minimization of the provided Energy functional.
@@ -121,28 +94,17 @@ class DescentMinimizer(Loggable, object):
"""
convergence = 0
f_k_minus_1 = None
iteration_number = 1
controller = self._controller
status = controller.start(energy)
if status != controller.CONTINUE:
return E, status
while True:
if self.callback is not None:
try:
self.callback(energy, iteration_number)
except StopIteration:
self.logger.info("Minimization was stopped by callback "
"function.")
break
# compute the the gradient for the current location
gradient = energy.gradient
gradient_norm = gradient.norm()
# check if position is at a flat point
if gradient_norm == 0:
if energy.gradient_norm == 0:
self.logger.info("Reached perfectly flat point. Stopping.")
convergence = self.convergence_level+2
break
return energy, controller.CONVERGED
# current position is encoded in energy object
descent_direction = self.get_descent_direction(energy)
@@ -157,47 +119,20 @@ class DescentMinimizer(Loggable, object):
            except RuntimeError:
                self.logger.warn(
                    "Stopping because of RuntimeError in line-search")
                break
                return energy, controller.ERROR

            f_k_minus_1 = energy.value
            f_k = new_energy.value

            delta = (abs(f_k-f_k_minus_1) /
                     max(abs(f_k), abs(f_k_minus_1), 1.))

            # check if new energy value is bigger than old energy value
            if (new_energy.value - energy.value) > 0:
                self.logger.info("Line search algorithm returned a new energy "
                                 "that was larger than the old one. Stopping.")
                break
                return energy, controller.ERROR

            energy = new_energy

            # check convergence
            self.logger.debug("Iteration:%08u "
                              "delta=%3.1E energy=%3.1E" %
                              (iteration_number, delta, energy.value))
            if delta == 0:
                convergence = self.convergence_level + 2
                self.logger.info("Found minimum according to line-search. "
                                 "Stopping.")
                break
            elif delta < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break

            iteration_number += 1

        return energy, convergence
            status = self._controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

    @abc.abstractmethod
    def get_descent_direction(self, energy):
......
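The only hook a concrete subclass must supply is get_descent_direction. A steepest-descent choice, mirroring what the SteepestDescent class exported in __init__.py presumably provides, would look like this (illustrative sketch, not part of the commit):

class MySteepestDescent(DescentMinimizer):
    def get_descent_direction(self, energy):
        # move against the gradient; the line search fixes the step length
        return -energy.gradient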
iteration_controller.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
from nifty.nifty_meta import NiftyMeta
import numpy as np
from keepers import Loggable
class IterationController(Loggable, object):
    __metaclass__ = NiftyMeta

    CONVERGED, CONTINUE, ERROR = range(3)
    @abc.abstractmethod
    def start(self, energy):
        """
        Parameters
        ----------
        energy : Energy object
            Energy object at the start of the iteration

        Returns
        -------
        status : integer
            Can be CONVERGED, CONTINUE or ERROR
        """
        raise NotImplementedError
    @abc.abstractmethod
    def check(self, energy):
        """
        Parameters
        ----------
        energy : Energy object
            Energy object at the current iteration step

        Returns
        -------
        status : integer
            Can be CONVERGED, CONTINUE or ERROR
        """
        raise NotImplementedError
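Any controller only has to implement start() and check() and return one of the three status constants. A hypothetical minimal implementation that stops unconditionally after a fixed number of checks (illustrative, not part of this commit):

class FixedCountController(IterationController):
    def __init__(self, maxiter):
        super(FixedCountController, self).__init__()
        self._maxiter = maxiter

    def start(self, energy):
        self._itcount = 0          # reset state at the start of a run
        return self.check(energy)

    def check(self, energy):
        self._itcount += 1
        if self._itcount > self._maxiter:
            return self.CONVERGED
        return self.CONTINUE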
minimizer.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
from nifty.nifty_meta import NiftyMeta
import numpy as np
from keepers import Loggable
class Minimizer(Loggable, object):
    """ A base class used by all minimizers.
    """

    __metaclass__ = NiftyMeta

    @abc.abstractmethod
    def __call__(self, energy):
        """ Performs the minimization of the provided Energy functional.

        Parameters
        ----------
        energy : Energy object
            Energy object which provides value, gradient and curvature at a
            specific position in parameter space.

        Returns
        -------
        energy : Energy object
            Latest `energy` of the minimization.
        status : integer
            Can be CONVERGED, CONTINUE or ERROR

        """
        raise NotImplementedError
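The contract a concrete Minimizer fulfils: ask its controller to start(), take steps while it returns CONTINUE, and hand back the final energy together with the controller's status. A hypothetical fixed-step gradient minimizer as a sketch; energy.at, energy.position and energy.gradient are used exactly as in the ConjugateGradient code above, and the class itself is illustrative.

class FixedStepMinimizer(Minimizer):
    def __init__(self, controller, step_size=0.1):
        self._controller = controller
        self._step_size = step_size

    def __call__(self, energy):
        status = self._controller.start(energy)
        while status == self._controller.CONTINUE:
            # step downhill along the negative gradient
            energy = energy.at(energy.position
                               - energy.gradient*self._step_size)
            status = self._controller.check(energy)
        return energy, status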
vl_bfgs.py
@@ -23,16 +23,12 @@ from .line_searching import LineSearchStrongWolfe
class VL_BFGS(DescentMinimizer):
    def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None,
                 convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None, max_history_length=5):
    def __init__(self, controller, line_searcher=LineSearchStrongWolfe(),
                 max_history_length=5):
        super(VL_BFGS, self).__init__(
            line_searcher=line_searcher,
            callback=callback,
            convergence_tolerance=convergence_tolerance,
            convergence_level=convergence_level,
            iteration_limit=iteration_limit)
            controller=controller,
            line_searcher=line_searcher)

        self.max_history_length = max_history_length
......
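Under the restructured API, callers build an IterationController first and hand it to the minimizer instead of passing tolerances and callbacks directly. A hedged construction example using the classes from this commit (parameter values are arbitrary):

controller = DefaultIterationController(tol_gradnorm=1e-7,
                                        convergence_level=3,
                                        iteration_limit=1000)
minimizer = VL_BFGS(controller=controller, max_history_length=5)
# energy, status = minimizer(initial_energy)  # initial_energy: an Energy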