# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
import numpy as np

from keepers import Loggable


class ConjugateGradient(Loggable, object):
    """Implementation of the Conjugate Gradient scheme.
    
    It is an iterative method for solving a linear system of equations:
                                    Ax = b
    
    SUGESTED LITERATURE:
        Thomas V. Mikosch et al., "Numerical Optimization", Second Edition, 
        2006, Springer-Verlag New York
        
    Parameters
    ----------
    convergence_tolerance : scalar
        Tolerance specifying convergence. (default: 1E-4)
    convergence_level : integer
        Number of times the tolerance must be undershot before exiting.
        (default: 3)
    iteration_limit : integer, *optional*
        Maximum number of iterations performed. (default: None)
    reset_count : integer, *optional*
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions. (default: None)
    preconditioner : function, *optional*
        The user can provide a function which transforms the variables of
        the system to make convergence more favorable. (default: None)
    callback : function, *optional*
        Function f(x, iteration_number) called at every iteration step,
        e.g. to print the iteration number and the current estimate. It
        receives the current estimate (x) and an integer
        (iteration_number). (default: None)

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before exiting.
    iteration_limit : integer
        Maximum number of iterations performed.
    reset_count : integer
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions.
    preconditioner : function
        The user can provide a function which transforms the variables of
        the system to make convergence more favorable.
    callback : function
        Function f(x, iteration_number) called at every iteration step. It
        receives the current estimate (x) and an integer (iteration_number).
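
    Examples
    --------
    A minimal sketch rather than the intended NIFTy usage: any objects
    providing a `dot` method work, so plain NumPy arrays can stand in for
    Fields, with `A` passed as a callable wrapping a symmetric
    positive-definite matrix. (Log messages go to the class logger, not
    stdout.)

    >>> import numpy as np
    >>> M = np.array([[4., 1.],
    ...               [1., 3.]])
    >>> cg = ConjugateGradient(convergence_tolerance=1E-10)
    >>> x, convergence = cg(A=lambda v: M.dot(v), b=np.array([1., 2.]),
    ...                     x0=np.zeros(2))
    >>> bool(np.allclose(M.dot(x), np.array([1., 2.])))
    True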
    
    """    
    
    def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None, reset_count=None,
                 preconditioner=None, callback=None):

        self.convergence_tolerance = float(convergence_tolerance)
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        if reset_count is not None:
            reset_count = int(reset_count)
        self.reset_count = reset_count

        if preconditioner is None:
            # Default to the identity, i.e. no preconditioning.
            preconditioner = lambda z: z

        self.preconditioner = preconditioner
        self.callback = callback

    def __call__(self, A, b, x0):
        """Runs the conjugate gradient minimization.

        Parameters
        ----------
        A : Operator
            Operator `A` applicable to a Field.
        b : Field
            Resulting Field of the operation `A(x)`.
        x0 : Field
            Starting guess for the minimization.

        Returns
        -------
        x : Field
            Latest `x` of the minimization.
        convergence : integer
            Latest convergence level indicating whether the minimization
            has converged or not.

        """
        r = b - A(x0)
        d = self.preconditioner(r)
        previous_gamma = r.dot(d)
        if previous_gamma == 0:
            self.logger.info("The starting guess is already a perfect "
                             "solution of the inverse problem.")
            return x0, self.convergence_level+1
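
        # The residual is later reported relative to the norm of b.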
        norm_b = np.sqrt(b.dot(b))

        x = x0
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")

        while True:
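            # Optional user hook, e.g. for monitoring intermediate results.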
            if self.callback is not None:
                self.callback(x, iteration_number)

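            # One operator application per iteration: q = A(d). The step
            # length alpha = (r, s)/(d, q) minimizes the quadratic along d.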
            q = A(d)
            alpha = previous_gamma/d.dot(q)

            if not np.isfinite(alpha):
                self.logger.error("Alpha became non-finite! Stopping.")
                return x0, 0

            x += d * alpha

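            # Restart (forget the previous conjugate directions) if the
            # positive definiteness of A appears violated, or periodically
            # if `reset_count` is set; then recompute the exact residual.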
            reset = False
            if alpha.real < 0:
                self.logger.warn("Positive definiteness of A violated!")
                reset = True
            if self.reset_count is not None:
                reset = reset or (iteration_number % self.reset_count == 0)
            if reset:
                self.logger.info("Resetting conjugate directions.")
                r = b - A(x)
            else:
                r -= q * alpha

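            # Preconditioned residual; gamma = (r, s) drives both the
            # convergence measure and the next search direction.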
            s = self.preconditioner(r)
            gamma = r.dot(s)

            if gamma.real < 0:
                self.logger.warn("Positive definiteness of preconditioner "
                                 "violated!")

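            # Fletcher-Reeves-type factor; clipping at zero falls back to
            # the steepest-descent direction if the preconditioner broke
            # positive definiteness.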
            beta = max(0, gamma/previous_gamma)

            # Convergence measure: preconditioner-weighted residual norm,
            # relative to the norm of b.
            delta = np.sqrt(gamma)/norm_b

            self.logger.debug("Iteration : %08u   alpha = %3.1E   "
                              "beta = %3.1E   delta = %3.1E" %
                              (iteration_number,
                               np.real(alpha),
                               np.real(beta),
                               np.real(delta)))

            if gamma == 0:
                convergence = self.convergence_level+1
                self.logger.info("Reached infinite convergence.")
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break

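            # New search direction, conjugate to the previous ones.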
            d = s + d * beta

            iteration_number += 1
            previous_gamma = gamma

        return x, convergence