# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
from __future__ import absolute_import, division, print_function
20
21

from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..logger import logger
23
24
from .line_search_strong_wolfe import LineSearchStrongWolfe
from .minimizer import Minimizer
25
26


Martin Reinecke's avatar
Martin Reinecke committed
27
class DescentMinimizer(Minimizer):
    """ A base class used by gradient methods to find a local minimum.

    Descent minimization methods are used to find a local minimum of a scalar
    function by following a descent direction. This class implements the
    minimization procedure once a descent direction is known. The descent
    direction has to be implemented separately by overriding
    `get_descent_direction`.

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.
    line_searcher : line-search object *optional*
        Object which infers the step size in the descent direction. It must
        provide `perform_line_search(energy, pk, f_k_minus_1)` returning a
        `(new_energy, success)` pair. If None (the default), a fresh
        LineSearchStrongWolfe() is created for this minimizer.
    """

    def __init__(self, controller, line_searcher=None):
        self._controller = controller
        # Create the default line searcher per instance. A default-argument
        # expression is evaluated only once, so writing
        # `line_searcher=LineSearchStrongWolfe()` in the signature would make
        # every minimizer share (and mutate) the same object.
        if line_searcher is None:
            line_searcher = LineSearchStrongWolfe()
        self.line_searcher = line_searcher

    def __call__(self, energy):
        """ Performs the minimization of the provided Energy functional.

        Parameters
        ----------
        energy : Energy
           Energy object which provides value and gradient at a specific
           position in parameter space; it is the starting point of the
           minimization.

        Returns
        -------
        Energy
            Latest `energy` of the minimization.
        int
            Final status: controller.CONVERGED, controller.ERROR, or any
            other non-CONTINUE status reported by the controller.

        Notes
        -----
        The minimization is stopped if
            * the controller returns a status other than controller.CONTINUE,
              either in `start` or in `check` after an accepted step,
            * a perfectly flat point is reached (gradient norm exactly 0),
              which is reported as controller.CONVERGED,
            * a step increases the energy; this is reported as
              controller.ERROR together with the energy *before* that step,
            * a step leaves the energy unchanged; convergence is assumed and
              controller.CONVERGED is returned with the new energy.
        """
        controller = self._controller
        # Energy value of the previous iterate, handed to the line search;
        # there is none before the first step.
        f_k_minus_1 = None

        status = controller.start(energy)
        if status != controller.CONTINUE:
            return energy, status

        while True:
            # check if position is at a flat point
            if energy.gradient_norm == 0:
                return energy, controller.CONVERGED

            # compute a step length that reduces energy.value sufficiently
            new_energy, success = self.line_searcher.perform_line_search(
                energy=energy, pk=self.get_descent_direction(energy),
                f_k_minus_1=f_k_minus_1)
            if not success:
                # Give subclasses a chance to discard accumulated state; the
                # iteration still proceeds with the energy returned above.
                self.reset()

            f_k_minus_1 = energy.value

            if new_energy.value > energy.value:
                logger.error("Error: Energy has increased")
                # Hand back the last energy that did not increase.
                return energy, controller.ERROR

            if new_energy.value == energy.value:
                logger.warning(
                    "Warning: Energy has not changed. Assuming convergence...")
                return new_energy, controller.CONVERGED

            energy = new_energy
            status = controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

    def reset(self):
        """ Hook invoked by `__call__` whenever the line search reports
        failure.

        The base implementation does nothing; subclasses that keep state
        between iterations can override it to discard that state.
        """

    def get_descent_direction(self, energy):
        """ Calculates the next descent direction.

        Must be overridden by subclasses; the base implementation raises.

        Parameters
        ----------
        energy : Energy
            The energy at the current iterate of the minimization; its
            position is the point from which the descent direction starts.

        Returns
        -------
        Field
           The descent direction.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError