conjugate_gradient.py 7.05 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
20
21
from __future__ import division
import numpy as np

22
from keepers import Loggable
23

24

theos's avatar
theos committed
25
class ConjugateGradient(Loggable, object):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
                                    Ax = b

    Parameters
    ----------
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-4)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None).
    reset_count : integer *optional*
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions (default: None).
    preconditioner : Operator *optional*
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable *optional*
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current energy and iteration_number are passed. (default: None)

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer
        Maximum number of iterations performed.
    reset_count : integer
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions.
    preconditioner : function
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable
        Function f(energy, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current energy and iteration_number are passed. (default: None)

    References
    ----------
    Jorge Nocedal and Stephen J. Wright, "Numerical Optimization",
    Second Edition, 2006, Springer-Verlag New York

    """

    def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None, reset_count=None,
                 preconditioner=None, callback=None):
        # `np.float` was an alias for the builtin and has been removed from
        # recent NumPy releases; use the builtin `float` directly.
        self.convergence_tolerance = float(convergence_tolerance)
        # The convergence level is a counter threshold and therefore an
        # integer quantity (see class docstring).
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        if reset_count is not None:
            reset_count = int(reset_count)
        self.reset_count = reset_count

        if preconditioner is None:
            # Identity preconditioner: leaves the residual unchanged.
            def preconditioner(z):
                return z

        self.preconditioner = preconditioner
        self.callback = callback

    def __call__(self, A, b, x0):
        """ Runs the conjugate gradient minimization.

        For `Ax = b` the variable `x` is inferred.

        Parameters
        ----------
        A : Operator
            Operator `A` applicable to a Field.
        b : Field
            Result of the operation `A(x)`.
        x0 : Field
            Starting guess for the minimization. It is not modified by
            this call.

        Returns
        -------
        x : Field
            Latest `x` of the minimization.
        convergence : integer
            Latest convergence level indicating whether the minimization
            has converged or not.

        """
        # Initial residual and (preconditioned) search direction.
        r = b - A(x0)
        d = self.preconditioner(r)
        previous_gamma = r.dot(d)
        if previous_gamma == 0:
            self.logger.info("The starting guess is already perfect solution "
                             "for the inverse problem.")
            return x0, self.convergence_level+1
        norm_b = np.sqrt(b.dot(b))
        # Work on a copy: `x` is updated in place below, and aliasing it to
        # `x0` would silently clobber the caller's starting guess.
        x = x0.copy()
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")

        while True:
            if self.callback is not None:
                self.callback(x, iteration_number)

            q = A(d)
            alpha = previous_gamma/d.dot(q)

            if not np.isfinite(alpha):
                self.logger.error("Alpha became infinite! Stopping.")
                # Return the latest iterate with convergence level 0.
                return x, 0

            x += d * alpha

            reset = False
            if alpha.real < 0:
                self.logger.warn("Positive definiteness of A violated!")
                reset = True
            if self.reset_count is not None:
                # Periodic restart: forget the accumulated conjugate
                # directions every `reset_count` iterations.
                reset = reset or (iteration_number % self.reset_count == 0)
            if reset:
                self.logger.info("Resetting conjugate directions.")
                # Recompute the residual from scratch instead of updating it.
                r = b - A(x)
            else:
                r -= q * alpha

            s = self.preconditioner(r)
            gamma = r.dot(s)

            if gamma.real < 0:
                self.logger.warn("Positive definiteness of preconditioner "
                                 "violated!")

            # Clip beta at zero (Polak-Ribiere style safeguard) so a negative
            # ratio cannot produce a non-descent direction.
            beta = max(0, gamma/previous_gamma)

            # Relative residual norm used as the convergence criterion.
            delta = np.sqrt(gamma)/norm_b

            self.logger.debug("Iteration : %08u   alpha = %3.1E   "
                              "beta = %3.1E   delta = %3.1E" %
                              (iteration_number,
                               np.real(alpha),
                               np.real(beta),
                               np.real(delta)))

            if gamma == 0:
                convergence = self.convergence_level+1
                self.logger.info("Reached infinite convergence.")
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                # A non-converging step decays the convergence counter.
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break

            # New search direction, conjugated against the previous one.
            d = s + d * beta

            iteration_number += 1
            previous_gamma = gamma

        return x, convergence