# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
import numpy as np
from .minimizer import Minimizer
from .line_searching import LineSearchStrongWolfe
class DescentMinimizer(Minimizer):
    """ A base class used by gradient methods to find a local minimum.

    Descent minimization methods are used to find a local minimum of a scalar
    function by following a descent direction. This class implements the
    minimization procedure once a descent direction is known. The descent
    direction has to be implemented separately.

    Parameters
    ----------
    controller : object
        Decides when the minimization stops. It must provide the methods
        `start(energy)` and `check(energy)`, both returning a status value,
        and the status constants `CONTINUE`, `CONVERGED` and `ERROR`.
    line_searcher : LineSearch *optional*
        Object which infers the step size in the descent direction via its
        `perform_line_search` method
        (default : a new LineSearchStrongWolfe() for every minimizer).

    Attributes
    ----------
    line_searcher : LineSearch
        Object which infers the optimal step size for functional
        minimization given a descent direction.

    Notes
    -----
    Check `get_descent_direction` of a derived class for information on the
    concrete minimization scheme.
    """

    def __init__(self, controller, line_searcher=None):
        super(DescentMinimizer, self).__init__()
        # Create the default line searcher per instance instead of once at
        # import time (a default argument is evaluated only once), so any
        # state a line searcher may keep is not shared between minimizers.
        if line_searcher is None:
            line_searcher = LineSearchStrongWolfe()
        self.line_searcher = line_searcher
        self._controller = controller

    def __call__(self, energy):
        """ Performs the minimization of the provided Energy functional.

        Parameters
        ----------
        energy : Energy object
            Energy object which provides value, gradient and curvature at a
            specific position in parameter space.

        Returns
        -------
        energy : Energy object
            Latest `energy` of the minimization.
        status : controller status
            The status returned by the controller, or `controller.CONVERGED`
            / `controller.ERROR` when this method stops on its own,
            indicating why the minimization ended.

        Note
        ----
        The minimization is stopped if
            * the controller's `start` or `check` returns a status other than
              `controller.CONTINUE`,
            * a perfectly flat point (zero gradient norm) is reached
              (returns `controller.CONVERGED`),
            * the line search raises a `RuntimeError`
              (returns `controller.ERROR`),
            * the line search returns an energy whose value is larger than
              the current one (returns `controller.ERROR`).
        """
        f_k_minus_1 = None
        controller = self._controller

        status = controller.start(energy)
        if status != controller.CONTINUE:
            # The controller refused to start; hand back the initial energy.
            return energy, status

        while True:
            # check if position is at a flat point
            if energy.gradient_norm == 0:
                self.logger.info("Reached perfectly flat point. Stopping.")
                return energy, controller.CONVERGED

            # current position is encoded in energy object
            descent_direction = self.get_descent_direction(energy)

            # compute the step length, which minimizes energy.value along the
            # search direction
            try:
                new_energy = self.line_searcher.perform_line_search(
                    energy=energy,
                    pk=descent_direction,
                    f_k_minus_1=f_k_minus_1)
            except RuntimeError:
                self.logger.warning(
                    "Stopping because of RuntimeError in line-search")
                return energy, controller.ERROR

            # The current value becomes the "previous" value handed to the
            # line search in the next iteration.
            f_k_minus_1 = energy.value

            # check if new energy value is bigger than old energy value
            if (new_energy.value - energy.value) > 0:
                self.logger.info("Line search algorithm returned a new energy "
                                 "that was larger than the old one. Stopping.")
                return energy, controller.ERROR

            energy = new_energy
            status = controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

    @abc.abstractmethod
    def get_descent_direction(self, energy):
        """ Returns the direction in which to search for a lower energy.

        Derived classes implement the concrete minimization scheme here. The
        returned direction is passed as `pk` to the line searcher's
        `perform_line_search`.

        Parameters
        ----------
        energy : Energy object
            The energy at the current position.
        """
        raise NotImplementedError