krylov_sampling.py 2.67 KB
Newer Older
1
2
3
4
5
import nifty4 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty4.sugar import create_power_operator

Philipp Arras's avatar
Philipp Arras committed
6
7
# Seed NumPy's global random number generator so repeated runs of this
# script are reproducible (presumably the draw_sample() calls below draw
# from it -- TODO confirm against the nifty4 implementation).
np.random.seed(42)

8
9
10
11
# RGSpace with 1024 pixels, together with its default codomain; h_space is
# the domain on which the prior, the amplitude operator and the harmonic
# transform below are defined.
x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()

# The data are defined on the very same space as x_space.
d_space = x_space
Martin Reinecke's avatar
Martin Reinecke committed
12
13
14
15
# Diagonal of the operator N: 10 everywhere, except for the index range
# 400..449, which is set to the much smaller value 1e-4.
noise_diag = np.full(d_space.shape, 10.)
noise_diag[400:450] = 0.0001

# Wrap the diagonal into a field on d_space and build the diagonal operator
# from which the noise is drawn further below.
N_hat = ift.Field.from_global_data(d_space, noise_diag)
N = ift.DiagonalOperator(N_hat)
16

Martin Reinecke's avatar
Martin Reinecke committed
17
# Harmonic transform built from h_space and x_space.
FFT = ift.HarmonicTransformOperator(h_space, x_space)

# Response: a scaling operator with factor 1 on x_space.
R = ift.ScalingOperator(1., x_space)
Philipp Arras's avatar
PEP8    
Philipp Arras committed
19
20
21
22
23
24


def ampspec(k):
    """Return the amplitude spectrum 1 / (1 + k**2) evaluated at *k*.

    *k* may be a scalar or anything supporting elementwise arithmetic
    (e.g. a NumPy array); the result has the corresponding shape.
    """
    return 1. / (1. + k**2.)


# S is a scaling operator with factor 1 on h_space; it is the operator the
# excitations are drawn from and is later passed as S to the sampling
# routines.
S = ift.ScalingOperator(1., h_space)

# Power operator on h_space built from the ampspec function.
A = create_power_operator(h_space, ampspec)

# Draw one sample from S; it serves as the ground truth on h_space.
s_h = S.draw_sample()
Philipp Arras's avatar
PEP8    
Philipp Arras committed
27
# Signal model: apply A, then the harmonic transform.
sky = FFT * A

# Synthetic data set: the ground-truth signal on x_space seen through R,
# plus one noise realisation drawn from N.
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n

Philipp Arras's avatar
PEP8    
Philipp Arras committed
32
# Full response acting directly on the excitations living on h_space.
R_p = R * FFT * A

# Information source: adjoint response applied to the noise-weighted data.
j = R_p.adjoint(N.inverse(d))

# Operator handed to the Krylov sampler as D_inv: the sandwich of N.inverse
# between R_p (presumably R_p^dagger N^-1 R_p -- confirm against the
# SandwichOperator docs) plus the inverse of S.
D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
35

Philipp Arras's avatar
PEP8    
Philipp Arras committed
36

37
# Number of samples per method and iteration budget of the controller.
N_samps = 200
N_iter = 100
IC = ift.GradientNormController(
    tol_abs_gradnorm=1e-3,
    iteration_limit=N_iter)

# Krylov approach: one call yields both m and the list of samples.
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
m_x = sky(m)

# Reference approach: draw the same number of samples one by one from the
# Wiener filter curvature, using conjugate gradient with the same controller.
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = [curv.draw_sample(from_inverse=True) for _ in range(N_samps)]
45

Martin Reinecke's avatar
Martin Reinecke committed
46
47
48
# Figure 1: data points, ground-truth signal and reconstruction m_x.
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
plt.plot(s_x.to_global_data(), label="original")
plt.plot(m_x.to_global_data(), label="reconstruction")

plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()
52

Philipp Arras's avatar
Philipp Arras committed
53
# Figure 2: every sample of both methods mapped through sky, drawn as thin
# translucent lines, with the true residual (signal - mean) on top.
pltdict = {'alpha': .3, 'linewidth': .2}
for idx in range(N_samps):
    old_style = dict(pltdict, color='b')
    new_style = dict(pltdict, color='r')
    if idx == 0:
        # Label only the first curve of each family so that the legend
        # shows a single entry per method.
        old_style['label'] = 'Traditional samples (residuals)'
        new_style['label'] = 'Krylov samples (residuals)'
    plt.plot(sky(samps_old[idx]).to_global_data(), **old_style)
    plt.plot(sky(samps[idx]).to_global_data(), **new_style)
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')

plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()

Martin Reinecke's avatar
Martin Reinecke committed
70
71
# Accumulate, for each method, the sum over all samples of the squared
# sample values after mapping them through sky.
D_hat_old = ift.Field.zeros(x_space).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data()
for idx in range(N_samps):
    old_values = sky(samps_old[idx]).to_global_data()
    D_hat_old += old_values**2
    new_values = sky(samps[idx]).to_global_data()
    D_hat_new += new_values**2
Philipp Arras's avatar
Philipp Arras committed
75
# Figure 3: root-mean-square of the samples as an uncertainty band for each
# method, compared with the actual residual (signal - mean).
sigma_old = np.sqrt(D_hat_old / N_samps)
sigma_new = np.sqrt(D_hat_new / N_samps)

# Traditional sampler: dashed envelope at +/- sigma.
plt.plot(sigma_old, 'r--', label='Traditional uncertainty')
plt.plot(-sigma_old, 'r--')

# Krylov sampler: shaded band between -sigma and +sigma.
plt.fill_between(range(len(D_hat_new)), -sigma_new, sigma_new,
                 facecolor='0.5', alpha=0.5, label='Krylov uncertainty')

plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')

plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()