# krylov_sampling.py
# Compares Krylov (conjugate-gradient) samples with traditional Wiener-filter
# curvature samples on a 1-D regular grid.
import nifty4 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty4.sugar import create_power_operator

# Seed NumPy's global RNG so repeated runs give the same draws.
np.random.seed(42)

# 1-D regular grid with 1024 pixels and its harmonic partner space; the data
# live on the same grid as the signal.
x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()
d_space = x_space

# Diagonal noise covariance: value 10 everywhere, except a nearly noise-free
# stretch at pixels 400..449.
N_hat = ift.Field(d_space, 10.)
N_hat.val[400:450] = 0.0001
N = ift.DiagonalOperator(N_hat, d_space)

# Harmonic -> position transform and an identity instrument response.
FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
R = ift.ScalingOperator(1., x_space)


def ampspec(k):
    """Return the spectrum value ``1 / (1 + k**2)`` at wavenumber(s) ``k``.

    Passed to ``create_power_operator`` to build the operator ``A``.  Only
    arithmetic operators are used, so it evaluates element-wise on NumPy
    arrays as well as on scalars.
    """
    denom = 1. + k**2.
    return 1. / denom


# Prior: unit-covariance excitations in harmonic space, shaped by the
# operator A built from ampspec and mapped to position space by FFT.
S = ift.ScalingOperator(1., h_space)
A = create_power_operator(h_space, ampspec)
s_h = S.draw_sample()
sky = FFT * A

# Synthetic data: true signal in position space plus one noise realisation.
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n

# Full response from harmonic excitations to data, the information source j
# and the operator D_inv = R_p^dagger N^-1 R_p + S^-1 whose inverse is sampled.
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = R_p.adjoint * N.inverse * R_p + S.inverse

# Module-level list that sample() appends a copy of its first sample to,
# once per iteration.
history = []


def sample(D_inv, S, j,  N_samps, N_iter):
    """Solve ``D_inv x = j`` with conjugate gradients and build Krylov samples.

    The solver is plain CG started from ``x = 0``.  Alongside it, each of
    ``N_samps`` samples starts as a draw from ``S``.  At every CG step, the
    sample's component along the current search direction ``p`` (projection
    in the ``D_inv`` inner product) is removed and replaced by ``p`` times a
    standard-normal variate scaled by ``1 / sqrt(p . D_inv p)``.

    Parameters
    ----------
    D_inv : operator
        Linear operator, applied by calling it; ``D_inv.domain`` is the space
        of the solution.  Assumed self-adjoint and positive definite, as CG
        requires -- NOTE(review): confirm for callers other than this script.
    S : object
        Provides ``draw_sample()``; used to initialise the samples.
    j : field
        Right-hand side of the linear system.
    N_samps : int
        Number of samples to produce (may be 0).
    N_iter : int
        Maximum number of CG iterations.

    Returns
    -------
    x : field
        CG approximation of ``D_inv^{-1} j`` after at most ``N_iter`` steps.
    y : list
        The ``N_samps`` samples.

    Side effects: at the start of each iteration a copy of the first sample
    (if any) is appended to the module-level ``history`` list, and one
    diagnostic line is printed per completed iteration.
    """
    x = ift.Field.zeros(D_inv.domain)
    r = j.copy()
    p = r.copy()
    # Cache D_inv(p): it is needed for both the step length and the residual
    # update, and applying the operator is the expensive part.
    Dp = D_inv(p)
    d = p.vdot(Dp)
    y = [S.draw_sample() for _ in range(N_samps)]

    for k in range(1, 1 + N_iter):
        # Zero curvature along p (e.g. j == 0 on the first step) means there
        # is no direction left to explore; stopping avoids 0/0 -> NaN.
        if d == 0.:
            break
        if y:
            history.append(y[0].copy())
        rr = r.vdot(r)
        gamma = rr / d
        if gamma == 0.:
            break
        x += gamma * p
        for i in range(N_samps):
            # Remove the D_inv-projection of the sample onto p ...
            y[i] -= p.vdot(D_inv(y[i])) * p / d
            # ... and replace it with a fresh N(0, 1/d) coefficient along p.
            y[i] += np.random.randn() / np.sqrt(d) * p
        print("variance iteration " + str(k) + ":", np.sqrt(p.vdot(p) / d))
        r_new = r - gamma * Dp
        beta = r_new.vdot(r_new) / rr
        r = r_new
        p = r + beta * p
        Dp = D_inv(p)
        d = p.vdot(Dp)
    return x, y


# Krylov sampler: CG estimate m (harmonic space) plus N_samps samples.
N_samps = 200
N_iter = 10
m, samps = sample(D_inv, S, j, N_samps, N_iter)
m_x = sky(m)

# Reference samples from the Wiener-filter curvature, with its conjugate
# gradient inverter capped at the same number of iterations.
IC = ift.GradientNormController(iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = [curv.draw_sample(from_inverse=True) for _ in range(N_samps)]

# Data, true signal and the Krylov estimate, all in position space.
plt.plot(d.val, '+', label="data", alpha=.5)
plt.plot(s_x.val, label="original")
plt.plot(m_x.val, label="reconstruction")
plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()

# Overlay every traditional (blue) and Krylov (red) sample in position space
# against the true residual. Only the first pair gets a legend entry; labels
# starting with an underscore are skipped by matplotlib's legend.
pltdict = {'alpha': .3, 'linewidth': .2}
for i, (s_old, s_new) in enumerate(zip(samps_old, samps)):
    first = i == 0
    plt.plot(sky(s_old).val, color='b',
             label='Traditional samples (residuals)' if first else '_nolegend_',
             **pltdict)
    plt.plot(sky(s_new).val, color='r',
             label='Krylov samples (residuals)' if first else '_nolegend_',
             **pltdict)
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()

# Pointwise root-mean-square of the samples in position space, for both
# sampling methods (the samples are not mean-subtracted here).
D_hat_old = ift.Field.zeros(x_space).val
D_hat_new = ift.Field.zeros(x_space).val
for s_old, s_new in zip(samps_old, samps):
    D_hat_old += sky(s_old).val**2
    D_hat_new += sky(s_new).val**2
sigma_old = np.sqrt(D_hat_old / N_samps)
sigma_new = np.sqrt(D_hat_new / N_samps)

plt.plot(sigma_old, 'r--', label='Traditional uncertainty')
plt.plot(-sigma_old, 'r--')
plt.fill_between(range(len(sigma_new)), -sigma_new, sigma_new,
                 facecolor='0.5', alpha=0.5, label='Krylov uncertainty')
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()
# Position-space view of the first Krylov sample as recorded at the start of
# each of the first six iterations, against the true residual.
for step, snapshot in enumerate(history[:6], start=1):
    plt.plot(sky(snapshot).val, label="step " + str(step))
plt.plot(s_x.val - m_x.val, 'k-', label="residual")
plt.legend()
plt.savefig('iterations.png')
plt.close()