descent_minimizer.py 4.25 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import division
20
import abc
Martin Reinecke's avatar
Martin Reinecke committed
21
from .minimizer import Minimizer
Martin Reinecke's avatar
Martin Reinecke committed
22
from .line_search_strong_wolfe import LineSearchStrongWolfe
23
from .. import dobj
24 25


Martin Reinecke's avatar
Martin Reinecke committed
26
class DescentMinimizer(Minimizer):
    """ A base class used by gradient methods to find a local minimum.

    Descent minimization methods are used to find a local minimum of a scalar
    function by following a descent direction. This class implements the
    minimization procedure once a descent direction is known. The descent
    direction has to be implemented separately by subclasses through
    `get_descent_direction`.

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.
    line_searcher : object *optional*
        Object providing ``perform_line_search(energy, pk, f_k_minus_1)``,
        which infers the step size in the descent direction and returns the
        resulting Energy. If omitted (or None), a new
        ``LineSearchStrongWolfe()`` instance is created for this minimizer.
    """

    def __init__(self, controller, line_searcher=None):
        super(DescentMinimizer, self).__init__()
        self._controller = controller
        # Create the default per instance: a default instance in the
        # signature would be built once at import time and then shared by
        # every DescentMinimizer relying on the default.
        if line_searcher is None:
            line_searcher = LineSearchStrongWolfe()
        self.line_searcher = line_searcher

    def __call__(self, energy):
        """ Performs the minimization of the provided Energy functional.

        Parameters
        ----------
        energy : Energy
           Energy object which provides value, gradient and curvature at a
           specific position in parameter space. This method reads its
           `value` and `gradient_norm`.

        Returns
        -------
        Energy
            Latest `energy` of the minimization.
        int
            Status reported for the run, returned together with the energy
            as a 2-tuple. It is controller.CONVERGED, controller.ERROR, or
            whatever non-CONTINUE status the controller itself reported.

        Notes
        -----
        The minimization is stopped if
            * the controller returns a status other than controller.CONTINUE
              (at start or after an accepted step),
            * a perfectly flat point is reached (gradient norm is zero;
              reported as controller.CONVERGED),
            * the line search raises a ValueError (reported as
              controller.ERROR; the last accepted energy is returned),
            * a step increases the energy (reported as controller.ERROR; the
              energy before the step is returned),
            * a step leaves the energy unchanged (reported as
              controller.CONVERGED; the new energy is returned).
        """
        # energy value of the previous iterate; handed to the line search
        f_k_minus_1 = None
        controller = self._controller
        status = controller.start(energy)
        if status != controller.CONTINUE:
            return energy, status

        while True:
            # check if position is at a flat point
            if energy.gradient_norm == 0:
                return energy, controller.CONVERGED

            # compute a step length that reduces energy.value sufficiently
            try:
                new_energy = self.line_searcher.perform_line_search(
                    energy=energy, pk=self.get_descent_direction(energy),
                    f_k_minus_1=f_k_minus_1)
            except ValueError:
                return energy, controller.ERROR

            f_k_minus_1 = energy.value

            if new_energy.value > energy.value:
                dobj.mprint("Error: Energy has increased")
                return energy, controller.ERROR

            if new_energy.value == energy.value:
                dobj.mprint(
                    "Warning: Energy has not changed. Assuming convergence...")
                return new_energy, controller.CONVERGED

            energy = new_energy
            status = controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

    @abc.abstractmethod
    def get_descent_direction(self, energy):
        """ Calculates the next descent direction.

        Parameters
        ----------
        energy : Energy
            An instance of the Energy class which shall be minimized. The
            position of `energy` is used as the starting point of minimization.

        Returns
        -------
        Field
           The descent direction.
        """
        raise NotImplementedError