conjugate_gradient.py 3.91 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13 14 15 16 17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

19
from __future__ import division
Martin Reinecke's avatar
Martin Reinecke committed
20
from .minimizer import Minimizer
21
from .. import Field, dobj
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..utilities import general_axpy
23

Martin Reinecke's avatar
Martin Reinecke committed
24

Martin Reinecke's avatar
Martin Reinecke committed
25
class ConjugateGradient(Minimizer):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
                                    Ax = b

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.

    References
    ----------
    Jorge Nocedal & Stephen Wright, "Numerical Optimization", Second Edition,
    2006, Springer-Verlag New York
    """

    def __init__(self, controller):
        self._controller = controller

    def __call__(self, energy, preconditioner=None):
        """ Runs the conjugate gradient minimization.

        Parameters
        ----------
        energy : Energy object at the starting point of the iteration.
            Its curvature operator must be independent of position, otherwise
            linear conjugate gradient minimization will fail.
            The gradient of this object is copied before iterating, so the
            passed-in energy itself is not modified.
        preconditioner : Operator *optional*
            This operator can be provided which transforms the variables of the
            system to improve the conditioning (default: None).

        Returns
        -------
        energy : QuadraticEnergy
            state at last point of the iteration
        status : integer
            controller.CONVERGED if the (preconditioned) residual norm
            r^H P r became exactly zero, controller.ERROR on a numerical
            breakdown (zero or negative curvature along the search direction,
            or a preconditioner that is not positive definite), or any
            non-CONTINUE status returned by controller.start / controller.check.
        """
        controller = self._controller
        status = controller.start(energy)
        if status != controller.CONTINUE:
            return energy, status

        # The residual is updated in place below; work on a private copy so
        # the gradient stored in the caller's energy object is left untouched.
        r = energy.gradient.copy()
        if preconditioner is not None:
            d = preconditioner(r)
        else:
            d = r.copy()
        previous_gamma = (r.vdot(d)).real
        # Same positive-definiteness check as inside the loop; without it a
        # bad preconditioner only shows up later as a misleading "alpha<0".
        if previous_gamma < 0:
            dobj.mprint(
                "Positive definiteness of preconditioner violated!")
            return energy, controller.ERROR
        if previous_gamma == 0:
            return energy, controller.CONVERGED

        # Scratch buffer for the next position, reused every iteration.
        # NOTE(review): this (and `r`) is handed to `at_with_grad`; if energies
        # keep references instead of copies, only the newest energy is valid.
        tpos = Field(d.domain, dtype=d.dtype)
        while True:
            q = energy.curvature(d)
            ddotq = d.vdot(q).real
            if ddotq == 0.:
                dobj.mprint("Error: ConjugateGradient: ddotq==0.")
                return energy, controller.ERROR
            alpha = previous_gamma/ddotq

            # previous_gamma > 0 here, so a negative step length means the
            # curvature is not positive along d.
            if alpha < 0:
                dobj.mprint("Error: ConjugateGradient: alpha<0.")
                return energy, controller.ERROR

            # r <- r - alpha*q  (gradient update without re-applying A)
            general_axpy(-alpha, q, r, out=r)

            # x <- x - alpha*d
            general_axpy(-alpha, d, energy.position, out=tpos)
            energy = energy.at_with_grad(tpos, r)

            if preconditioner is not None:
                s = preconditioner(r)
            else:
                s = r

            gamma = r.vdot(s).real
            if gamma < 0:
                dobj.mprint(
                    "Positive definiteness of preconditioner violated!")
                return energy, controller.ERROR
            if gamma == 0:
                return energy, controller.CONVERGED

            status = controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

            # d <- beta*d + s with beta = gamma/previous_gamma; both are
            # normally positive here, the max() is only a safeguard.
            general_axpy(max(0, gamma/previous_gamma), d, s, out=d)

            previous_gamma = gamma