conjugate_gradient.py 7.31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
14
15
16
17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

19
20
21
from __future__ import division
import numpy as np

22
from keepers import Loggable
Martin Reinecke's avatar
Martin Reinecke committed
23
from nifty import Field
24

25

theos's avatar
theos committed
26
class ConjugateGradient(Loggable, object):
    """ Implementation of the Conjugate Gradient scheme.

    It is an iterative method for solving a linear system of equations:
                                    Ax = b

    Parameters
    ----------
    convergence_tolerance : float *optional*
        Tolerance specifying the case of convergence. (default: 1E-8)
    convergence_level : integer *optional*
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer *optional*
        Maximum number of iterations performed (default: None).
    reset_count : integer *optional*
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions (default: None).
    preconditioner : Operator *optional*
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable *optional*
        Function f(x, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current position x and iteration_number are passed. (default: None)

    Attributes
    ----------
    convergence_tolerance : float
        Tolerance specifying the case of convergence.
    convergence_level : integer
        Number of times the tolerance must be undershot before convergence
        is reached. (default: 3)
    iteration_limit : integer
        Maximum number of iterations performed.
    reset_count : integer
        Number of iterations after which to restart; i.e., forget previous
        conjugated directions.
    preconditioner : function
        This operator can be provided which transforms the variables of the
        system to improve the conditioning (default: None).
    callback : callable
        Function f(x, iteration_number) supplied by the user to perform
        in-situ analysis at every iteration step. When being called the
        current position x and iteration_number are passed. (default: None)

    References
    ----------
    Jorge Nocedal & Stephen Wright, "Numerical Optimization", Second Edition,
    2006, Springer-Verlag New York

    """

    def __init__(self, convergence_tolerance=1E-8, convergence_level=3,
                 iteration_limit=None, reset_count=None,
                 preconditioner=None, callback=None):

        self.convergence_tolerance = float(convergence_tolerance)
        # Keep the level integral: it is an undershoot *counter*, compared
        # with `==` and returned as `level+1` on immediate convergence.
        self.convergence_level = int(convergence_level)

        if iteration_limit is not None:
            iteration_limit = int(iteration_limit)
        self.iteration_limit = iteration_limit

        if reset_count is not None:
            reset_count = int(reset_count)
        self.reset_count = reset_count

        # Default preconditioner is the identity.
        if preconditioner is None:
            preconditioner = lambda z: z

        self.preconditioner = preconditioner
        self.callback = callback

    def __call__(self, A, b, x0):
        """ Runs the conjugate gradient minimization.

        For `Ax = b` the variable `x` is inferred.

        Parameters
        ----------
        A : Operator
            Operator `A` applicable to a Field.
        b : Field
            Result of the operation `A(x)`.
        x0 : Field
            Starting guess for the minimization.

        Returns
        -------
        x : Field
            Latest `x` of the minimization.
        convergence : integer
            Latest convergence level indicating whether the minimization
            has converged or not.

        """
        # Initial residual and (preconditioned) search direction.
        r = b - A(x0)
        d = self.preconditioner(r)
        previous_gamma = (r.vdot(d)).real
        if previous_gamma == 0:
            self.logger.info("The starting guess is already perfect solution "
                             "for the inverse problem.")
            return x0, self.convergence_level+1
        norm_b = np.sqrt((b.vdot(b)).real)
        x = x0.copy()
        convergence = 0
        iteration_number = 1
        self.logger.info("Starting conjugate gradient.")

        while True:
            if self.callback is not None:
                self.callback(x, iteration_number)

            q = A(d)
            alpha = previous_gamma/d.vdot(q).real

            if not np.isfinite(alpha):
                self.logger.error("Alpha became infinite! Stopping.")
                return x0, 0

            x += d * alpha

            reset = False
            if alpha < 0:
                # A negative step length means d.vdot(A(d)) < 0, which can
                # only happen if A is not positive definite.
                self.logger.warn("Positive definiteness of A violated!")
                reset = True
            if self.reset_count is not None:
                reset = reset or (iteration_number % self.reset_count == 0)
            if reset:
                self.logger.info("Resetting conjugate directions.")
                # Recompute the residual from scratch to discard accumulated
                # round-off error (instead of the cheap recursive update).
                r = b - A(x)
            else:
                r -= q * alpha

            s = self.preconditioner(r)
            gamma = r.vdot(s).real

            if gamma < 0:
                self.logger.warn("Positive definiteness of preconditioner "
                                 "violated!")

            # Clip beta at zero: a negative gamma would otherwise produce a
            # descent-destroying direction update.
            beta = max(0, gamma/previous_gamma)

            # Relative residual norm used as the convergence criterion.
            delta = np.sqrt(gamma)/norm_b

            self.logger.debug("Iteration : %08u   alpha = %3.1E   "
                              "beta = %3.1E   delta = %3.1E" %
                              (iteration_number, alpha, beta, delta))

            if gamma == 0:
                convergence = self.convergence_level+1
                self.logger.info("Reached infinite convergence.")
                break
            elif abs(delta) < self.convergence_tolerance:
                convergence += 1
                self.logger.info("Updated convergence level to: %u" %
                                 convergence)
                if convergence == self.convergence_level:
                    self.logger.info("Reached target convergence level.")
                    break
            else:
                # Undershoots must be consecutive; decay the counter again.
                convergence = max(0, convergence-1)

            if self.iteration_limit is not None:
                if iteration_number == self.iteration_limit:
                    self.logger.warn("Reached iteration limit. Stopping.")
                    break

            d = s + d * beta

            iteration_number += 1
            previous_gamma = gamma

        return x, convergence