# -*- coding: utf-8 -*-


from __future__ import division
import numpy as np

from keepers import Loggable


class ConjugateGradient(Loggable, object):
    def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
                 iteration_limit=None, reset_count=None,
                 preconditioner=None, callback=None):
        """
            Initializes the conjugate_gradient and sets the attributes (except
            for `x`).

            Parameters
            ----------
            A : {operator, function}
                Operator `A` applicable to a field.
            b : field
                Resulting field of the operation `A(x)`.
            W : {operator, function}, *optional*
                Operator `W` that is a preconditioner on `A` and is applicable to a
                field (default: None).
            spam : function, *optional*
                Callback function which is given the current `x` and iteration
                counter each iteration (default: None).
            reset : integer, *optional*
                Number of iterations after which to restart; i.e., forget previous
                conjugated directions (default: sqrt(b.dim)).
            note : bool, *optional*
                Indicates whether notes are printed or not (default: False).

        """
        self.convergence_tolerance = float(convergence_tolerance)
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        if reset_count is not None:
            reset_count = int(reset_count)
        self.reset_count = reset_count

        if preconditioner is None:
            preconditioner = lambda z: z

        self.preconditioner = preconditioner
        self.callback = callback

    def __call__(self, A, b, x0):
        """
            Runs the conjugate gradient minimization.

            Parameters
            ----------
            x0 : field, *optional*
                Starting guess for the minimization.
            tol : scalar, *optional*
                Tolerance specifying convergence; measured by current relative
                residual (default: 1E-4).
            clevel : integer, *optional*
                Number of times the tolerance should be undershot before
                exiting (default: 1).
            limii : integer, *optional*
                Maximum number of iterations performed (default: 10 * b.dim).

            Returns
            -------
            x : field
                Latest `x` of the minimization.
            convergence : integer
                Latest convergence level indicating whether the minimization
                has converged or not.

            """
        r = b - A(x0)
        d = self.preconditioner(r)
        previous_gamma = r.dot(d)
        if previous_gamma == 0:
            self.logger.info("The starting guess is already perfect solution "
                             "for the inverse problem.")
            return x0, self.convergence_level+1
        norm_b = np.sqrt(b.dot(b))
        x = x0
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")

        while True:
            if self.callback is not None:
                self.callback(x, iteration_number)

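            # exact line-search step size along d: alpha = (r.s) / (d.A(d))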
            q = A(d)
            alpha = previous_gamma/d.dot(q)

            if not np.isfinite(alpha):
                self.logger.error("Alpha became non-finite! Stopping.")
                return x0, 0

            x += d * alpha

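            # restart heuristics: a negative step size signals that A is not
            # positive definite; recomputing the exact residual every
            # `reset_count` iterations counters accumulated round-off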
            reset = False
            if alpha.real < 0:
                self.logger.warn("Positive definiteness of A violated!")
                reset = True
            if self.reset_count is not None:
                reset = reset or (iteration_number % self.reset_count == 0)
            if reset:
                self.logger.info("Resetting conjugate directions.")
                r = b - A(x)
            else:
                r -= q * alpha

            s = self.preconditioner(r)
            gamma = r.dot(s)

            if gamma.real < 0:
                self.logger.warn("Positive definitness of preconditioner "
                                 "violated!")

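            # Fletcher-Reeves-type update coefficient; clipping at zero makes
            # the next direction fall back to preconditioned steepest descent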
            beta = max(0, gamma/previous_gamma)

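            # convergence measure: residual norm in the preconditioner-weighted
            # inner product, relative to the norm of b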
            delta = np.sqrt(gamma)/norm_b

            self.logger.debug("Iteration : %08d   alpha = %3.1E   "
                              "beta = %3.1E   delta = %3.1E" %
                              (iteration_number,
                               np.real(alpha),
                               np.real(beta),
                               np.real(delta)))

            if gamma == 0:
                convergence = self.convergence_level+1
                self.logger.info("Reached infinite convergence.")
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break

            d = s + d * beta

            iteration_number += 1
            previous_gamma = gamma

        return x, convergence
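

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example that treats plain numpy arrays as "fields".  It assumes only what
# the solver above requires -- `A` is a callable, and fields support `dot`,
# `+`, `-`, `*` and `+=`.  The names `apply_A` and `jacobi` as well as the
# Jacobi preconditioner itself are illustrative assumptions.
if __name__ == "__main__":
    # small symmetric positive-definite test system
    matrix = np.array([[4., 1., 0.],
                       [1., 3., 1.],
                       [0., 1., 2.]])
    b = np.array([1., 2., 3.])

    def apply_A(x):
        # matrix-vector product standing in for the operator `A`
        return matrix.dot(x)

    def jacobi(z):
        # diagonal (Jacobi) preconditioning
        return z / np.diag(matrix)

    cg = ConjugateGradient(convergence_tolerance=1E-10,
                           iteration_limit=100,
                           preconditioner=jacobi)
    x, convergence = cg(apply_A, b, np.zeros(3))
    print("x = %r   convergence = %s" % (x, convergence))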