# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
Martin Reinecke's avatar
Martin Reinecke committed
20
from .minimizer import Minimizer
Martin Reinecke's avatar
Martin Reinecke committed
21
22
from ..field import Field
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
23
from ..utilities import general_axpy
24

Martin Reinecke's avatar
Martin Reinecke committed
25

Martin Reinecke's avatar
Martin Reinecke committed
26
class ConjugateGradient(Minimizer):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
                                    Ax = b

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.

    References
    ----------
    Jorge Nocedal & Stephen Wright, "Numerical Optimization", Second Edition,
    2006, Springer-Verlag New York
    """

    def __init__(self, controller):
        self._controller = controller

    def __call__(self, energy, preconditioner=None):
        """ Runs the conjugate gradient minimization.

        Parameters
        ----------
        energy : Energy object at the starting point of the iteration.
            Its curvature operator must be independent of position, otherwise
            linear conjugate gradient minimization will fail.
            This object is not modified by the minimization.
        preconditioner : Operator *optional*
            This operator can be provided which transforms the variables of the
            system to improve the conditioning (default: None). It must be
            positive definite; a violation is reported as controller.ERROR.

        Returns
        -------
        energy : QuadraticEnergy
            state at last point of the iteration
        status : integer
            controller.CONVERGED when the (preconditioned) residual vanishes,
            controller.ERROR on a numerical breakdown, or any non-CONTINUE
            status reported by controller.start / controller.check.
        """
        controller = self._controller
        status = controller.start(energy)
        if status != controller.CONTINUE:
            return energy, status

        # The residual is updated in place below (general_axpy with out=r).
        # Copy it so the caller's energy.gradient is not overwritten.
        r = energy.gradient.copy()
        if preconditioner is not None:
            d = preconditioner(r)
        else:
            d = r.copy()
        previous_gamma = (r.vdot(d)).real
        # Same positive-definiteness check as inside the loop; without it
        # a bad preconditioner could start the iteration on an indefinite
        # system or surface later as a misleading "alpha<0" error.
        if previous_gamma < 0:
            dobj.mprint(
                "Positive definiteness of preconditioner violated!")
            return energy, controller.ERROR
        if previous_gamma == 0:
            return energy, controller.CONVERGED

        tpos = Field(d.domain, dtype=d.dtype)  # temporary buffer
        while True:
            # Curvature applied to the current search direction.
            q = energy.curvature(d)
            ddotq = d.vdot(q).real
            if ddotq == 0.:
                dobj.mprint("Error: ConjugateGradient: ddotq==0.")
                return energy, controller.ERROR
            alpha = previous_gamma/ddotq

            # previous_gamma > 0 here, so alpha < 0 means d^T A d < 0,
            # i.e. the curvature is not positive definite along d.
            if alpha < 0:
                dobj.mprint("Error: ConjugateGradient: alpha<0.")
                return energy, controller.ERROR

            # r <- r - alpha*q  (residual/gradient update)
            general_axpy(-alpha, q, r, out=r)

            # x <- x - alpha*d; the updated gradient is passed along so the
            # new energy does not need to recompute it.
            general_axpy(-alpha, d, energy.position, out=tpos)
            energy = energy.at_with_grad(tpos, r)

            if preconditioner is not None:
                s = preconditioner(r)
            else:
                s = r

            gamma = r.vdot(s).real
            if gamma < 0:
                dobj.mprint(
                    "Positive definiteness of preconditioner violated!")
                return energy, controller.ERROR
            if gamma == 0:
                return energy, controller.CONVERGED

            status = controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

            # d <- beta*d + s, with beta = gamma/previous_gamma
            # (both gammas are positive at this point).
            general_axpy(max(0, gamma/previous_gamma), d, s, out=d)

            previous_gamma = gamma