krylov_sampling.py 3.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
import numpy as np
from ..minimization.quadratic_energy import QuadraticEnergy
21
22


Martin Reinecke's avatar
Martin Reinecke committed
23
def generate_krylov_samples(D_inv, S, j, N_samps, controller):
    """
    Generates inverse samples from a curvature D.

    This algorithm iteratively generates samples from
    a curvature D by applying conjugate gradient steps
    and resampling the curvature in search direction.

    Parameters
    ----------
    D_inv : EndomorphicOperator
        The curvature which will be the inverse of the covariance
        of the generated samples.
    S : EndomorphicOperator (from which one can sample)
        A prior covariance operator which is used to generate prior
        samples that are then iteratively updated.
    j : Field or None
        A Field to which the inverse of D_inv is applied. The solution
        of this matrix inversion problem is a side product of generating
        the samples.
        If None, it is sampled from the inverse prior.
    N_samps : int
        How many samples to generate.
    controller : IterationController
        Convergence controller for the conjugate gradient iteration.

    Returns
    -------
    (solution, samples) : A tuple of a field 'solution' and a list of fields
        'samples'. The first entry of the tuple is the solution x to

            D_inv(x) = j

        and the second entry is a list of samples from D_inv.inverse.
        If the iteration stops early (controller status other than
        CONTINUE, a vanishing residual, or a CG breakdown with
        ddotq == 0 or alpha < 0), the current position and the samples
        as updated so far are returned.
    """
    # The module's visible imports never bind `logger`; without this the
    # CG-breakdown branches below would raise NameError instead of logging.
    # NOTE(review): if the package provides a shared logger, prefer it here.
    import logging
    logger = logging.getLogger(__name__)

    # RL FIXME: make consistent with complex numbers
    j = S.draw_sample(from_inverse=True) if j is None else j
    energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
    # Prior samples; they are refined along every CG search direction below.
    y = [S.draw_sample() for _ in range(N_samps)]

    status = controller.start(energy)
    if status != controller.CONTINUE:
        return energy.position, y

    r = energy.gradient
    d = r.copy()

    previous_gamma = r.vdot(r).real
    if previous_gamma == 0:
        return energy.position, y

    while True:
        q = energy.curvature(d)
        ddotq = d.vdot(q).real
        if ddotq == 0.:
            logger.error("Error: ConjugateGradient: ddotq==0.")
            return energy.position, y
        alpha = previous_gamma/ddotq

        if alpha < 0:
            logger.error("Error: ConjugateGradient: alpha<0.")
            return energy.position, y

        # For every sample, remove its coefficient along d (y.vdot(q)/ddotq,
        # with q = curvature(d) before it is rescaled below) and add a fresh
        # random coefficient xi/sqrt(ddotq) with xi drawn from np.random.randn.
        sqrt_ddotq = np.sqrt(ddotq)
        for i, sample in enumerate(y):
            y[i] += (np.random.randn()*sqrt_ddotq - sample.vdot(q))/ddotq * d

        # Standard CG residual / position update.
        q *= -alpha
        r = r + q

        energy = energy.at_with_grad(energy.position - alpha*d, r)

        gamma = r.vdot(r).real
        if gamma == 0:
            return energy.position, y

        status = controller.check(energy)
        if status != controller.CONTINUE:
            return energy.position, y

        # New search direction (beta clipped at 0).
        d *= max(0, gamma/previous_gamma)
        d += r

        previous_gamma = gamma