# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import numpy as np

from keepers import Loggable


class ConjugateGradient(Loggable, object):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
                                    Ax = b

    Parameters
    ----------
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-4)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None).
    reset_count : integer *optional*
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions (default: None).
    preconditioner : Operator *optional*
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable *optional*
        Function f(x, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current position x and iteration_number are passed. (default: None)

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer
        Maximum number of iterations performed.
    reset_count : integer
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions.
    preconditioner : function
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable
        Function f(x, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current position x and iteration_number are passed. (default: None)

    References
    ----------
    Jorge Nocedal & Stephen Wright, "Numerical Optimization", Second Edition,
    2006, Springer-Verlag New York

    """

    def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None, reset_count=None,
                 preconditioner=None, callback=None):
        # FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
        # use the builtin float() instead. The convergence level is
        # documented and used as an integer counter, so cast it with int().
        self.convergence_tolerance = float(convergence_tolerance)
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        if reset_count is not None:
            reset_count = int(reset_count)
        self.reset_count = reset_count

        if preconditioner is None:
            # Default: identity preconditioner (plain CG).
            preconditioner = lambda z: z

        self.preconditioner = preconditioner
        self.callback = callback

    def __call__(self, A, b, x0):
        """ Runs the conjugate gradient minimization.

        For `Ax = b` the variable `x` is inferred.

        Parameters
        ----------
        A : Operator
            Operator `A` applicable to a Field.
        b : Field
            Result of the operation `A(x)`.
        x0 : Field
            Starting guess for the minimization.

        Returns
        -------
        x : Field
            Latest `x` of the minimization.
        convergence : integer
            Latest convergence level indicating whether the minimization
            has converged or not.

        """
        # Initial residual and preconditioned search direction.
        r = b - A(x0)
        d = self.preconditioner(r)
        previous_gamma = (r.vdot(d)).real
        if previous_gamma == 0:
            self.logger.info("The starting guess is already perfect solution "
                             "for the inverse problem.")
            return x0, self.convergence_level+1
        # Norm of b, used to express the residual norm relatively.
        norm_b = np.sqrt((b.vdot(b)).real)
        x = x0.copy()
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")

        # Initialized to inf so the first error log has defined values.
        beta = np.inf
        delta = np.inf

        while True:
            if self.callback is not None:
                self.callback(x, iteration_number)

            q = A(d)
            # Step length along the current search direction.
            alpha = previous_gamma/d.vdot(q).real

            if not np.isfinite(alpha):
                self.logger.error(
                        "Alpha became infinite! Stopping. Iteration : %08u   "
                        "alpha = %3.1E   beta = %3.1E   delta = %3.1E" %
                        (iteration_number, alpha, beta, delta))
                # Discard all progress and report no convergence.
                return x0, 0

            x += d * alpha

            reset = False
            if alpha < 0:
                # alpha < 0 implies d.vdot(A(d)) < 0, which cannot happen
                # for a positive definite A.
                self.logger.warn("Positive definiteness of A violated!")
                reset = True
            if self.reset_count is not None:
                reset = reset or (iteration_number % self.reset_count == 0)
            if reset:
                # Recompute the residual exactly to shed the round-off
                # accumulated by the recursive update.
                self.logger.info("Computing accurate residuum.")
                r = b - A(x)
            else:
                r -= q * alpha

            s = self.preconditioner(r)
            gamma = r.vdot(s).real

            if gamma < 0:
                self.logger.warn("Positive definiteness of preconditioner "
                                 "violated!")

            # Clamp beta at zero so a negative gamma cannot push the new
            # direction away from a descent direction.
            beta = max(0, gamma/previous_gamma)

            # Relative residual norm; this is the convergence measure.
            delta = np.sqrt(gamma)/norm_b

            self.logger.debug("Iteration : %08u   alpha = %3.1E   "
                              "beta = %3.1E   delta = %3.1E" %
                              (iteration_number, alpha, beta, delta))

            if gamma == 0:
                # Exact zero residual: the system is solved exactly.
                convergence = self.convergence_level+1
                self.logger.info(
                        "Reached infinite convergence. Iteration : %08u   "
                        "alpha = %3.1E   beta = %3.1E   delta = %3.1E" %
                        (iteration_number, alpha, beta, delta))
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info(
                        "Reached target convergence level. Iteration : %08u   "
                        "alpha = %3.1E   beta = %3.1E   delta = %3.1E" %
                        (iteration_number, alpha, beta, delta))
                    break
            else:
                # The undershoot streak was broken; decay the counter.
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.info(
                        "Reached iteration limit. Iteration : %08u   "
                        "alpha = %3.1E   beta = %3.1E   delta = %3.1E" %
                        (iteration_number, alpha, beta, delta))
                    break

            # New search direction, conjugated against the previous one.
            d = s + d * beta

            iteration_number += 1
            previous_gamma = gamma

        return x, convergence