# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import abc
from nifty.nifty_meta import NiftyMeta
import numpy as np

from keepers import Loggable

from .line_searching import LineSearchStrongWolfe


class DescentMinimizer(Loggable, object):
    """ A base class used by gradient methods to find a local minimum.

    Descent minimization methods are used to find a local minimum of a scalar
    function by following a descent direction. This class implements the
    minimization procedure once a descent direction is known. The descent
    direction has to be implemented separately.

    Parameters
    ----------
    line_searcher : LineSearch *optional*
        Object whose `perform_line_search(energy, pk, f_k_minus_1)` method
        infers the step size in the descent direction and returns the new
        energy. (default: None, meaning a new LineSearchStrongWolfe()
        instance is created for this minimizer)
    callback : callable *optional*
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current energy and iteration_number are passed. (default: None)
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-4)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None, i.e. no
        limit).

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached.
    iteration_limit : integer or None
        Maximum number of iterations performed.
    line_searcher : LineSearch
        Object which infers the optimal step size for functional
        minimization given a descent direction.
    callback : function or None
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current energy and iteration_number are passed.

    Notes
    -----
    The callback function can be used to externally stop the minimization by
    raising a `StopIteration` exception.
    Check `get_descent_direction` of a derived class for information on the
    concrete minimization scheme.

    """

    # Python-2-style metaclass declaration. Python 3 ignores a
    # `__metaclass__` class attribute, so NiftyMeta is only applied under
    # Python 2. NOTE(review): confirm the targeted Python version.
    __metaclass__ = NiftyMeta

    def __init__(self, line_searcher=None, callback=None,
                 convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None):
        # A line searcher created in the signature would be built once, at
        # class definition, and shared by every minimizer using the default.
        # Create a fresh one per instance instead.
        if line_searcher is None:
            line_searcher = LineSearchStrongWolfe()

        # Builtins instead of the np.float / np.int aliases, which NumPy
        # deprecated in 1.20 and removed in 1.24.
        self.convergence_tolerance = float(convergence_tolerance)
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        self.line_searcher = line_searcher
        self.callback = callback

    def __call__(self, energy):
        """ Performs the minimization of the provided Energy functional.

        Parameters
        ----------
        energy : Energy object
           Energy object which provides value, gradient and curvature at a
           specific position in parameter space.

        Returns
        -------
        energy : Energy object
            Latest `energy` of the minimization.
        convergence : integer
            Latest convergence level indicating whether the minimization
            has converged or not. It is set to `convergence_level + 2` when
            a perfectly flat point or a zero energy change is hit.

        Notes
        -----
        The minimization is stopped if
            * the callback function raises a `StopIteration` exception,
            * a perfectly flat point is reached,
            * the line search raises a `RuntimeError`,
            * the line search returns an energy larger than the current one,
            * according to the line-search the minimum is found,
            * the target convergence level is reached,
            * the iteration limit is reached.

        """

        convergence = 0
        f_k_minus_1 = None
        iteration_number = 1

        while True:
            if self.callback is not None:
                try:
                    self.callback(energy, iteration_number)
                except StopIteration:
                    self.logger.info("Minimization was stopped by callback "
                                     "function.")
                    break

            # compute the gradient for the current location
            gradient = energy.gradient
            gradient_norm = gradient.norm()

            # check if position is at a flat point
            if gradient_norm == 0:
                self.logger.info("Reached perfectly flat point. Stopping.")
                convergence = self.convergence_level+2
                break

            # current position is encoded in energy object
            descent_direction = self.get_descent_direction(energy)

            # compute the step length, which minimizes energy.value along the
            # search direction
            try:
                new_energy = self.line_searcher.perform_line_search(
                                                   energy=energy,
                                                   pk=descent_direction,
                                                   f_k_minus_1=f_k_minus_1)
            except RuntimeError:
                self.logger.warning(
                        "Stopping because of RuntimeError in line-search")
                break

            # The current value becomes the "previous" value handed to the
            # line search in the next iteration.
            f_k_minus_1 = energy.value
            f_k = new_energy.value
            # Relative change of the energy. The floor of 1 in the
            # denominator makes this an absolute change for small |values|.
            delta = (abs(f_k-f_k_minus_1) /
                     max(abs(f_k), abs(f_k_minus_1), 1.))

            # check if new energy value is bigger than old energy value
            if f_k > f_k_minus_1:
                self.logger.info("Line search algorithm returned a new energy "
                                 "that was larger than the old one. Stopping.")
                break

            energy = new_energy

            # check convergence
            self.logger.debug("Iteration:%08u delta=%3.1E energy=%3.1E",
                              iteration_number, delta, energy.value)
            if delta == 0:
                convergence = self.convergence_level + 2
                self.logger.info("Found minimum according to line-search. "
                                 "Stopping.")
                break
            elif delta < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u",
                                 convergence)
                # ">=" so that a non-positive convergence_level cannot make
                # this exit unreachable.
                if convergence >= self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                convergence = max(0, convergence-1)

            # ">=" so that a non-positive iteration_limit cannot cause an
            # endless loop.
            if (self.iteration_limit is not None and
                    iteration_number >= self.iteration_limit):
                self.logger.warning("Reached iteration limit. Stopping.")
                break

            iteration_number += 1

        return energy, convergence

    @abc.abstractmethod
    def get_descent_direction(self, energy):
        """ Returns the direction along which the line search proceeds.

        Must be overridden by derived classes. `__call__` passes the result
        as `pk` to `line_searcher.perform_line_search`.

        Parameters
        ----------
        energy : Energy object
            Energy at the current position of the minimization.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        """
        raise NotImplementedError