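# Wiener filter demo: reconstruct a correlated 1D signal from noisy data and
# compare posterior samples drawn with an accurately inverted curvature
# ("long" samples) against samples from a prematurely stopped conjugate
# gradient ("short" samples).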
import nifty5 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty5.sugar import create_power_operator

np.random.seed(42)

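# Signal space: 1D regular grid with 1024 pixels; h_space is its harmonic
# partner.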
x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()

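# Diagonal noise covariance: variance 10 everywhere except a low-noise
# window in pixels 400-450.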
d_space = x_space
N_hat = np.full(d_space.shape, 10.)
N_hat[400:450] = 0.0001
N_hat = ift.Field.from_global_data(d_space, N_hat)
N = ift.DiagonalOperator(N_hat)

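# Harmonic transform back to position space and a trivial (identity)
# instrument response.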
FFT = ift.HarmonicTransformOperator(h_space, x_space)
R = ift.ScalingOperator(1., x_space)


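# Spectrum applied by the amplitude operator A below; it falls off as
# 1 / (1 + k^2).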
def ampspec(k): return 1. / (1. + k**2.)


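# Synthetic data: draw white excitations s_h from the unit-covariance prior S,
# color them with A, transform to position space and add noise.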
S = ift.ScalingOperator(1., h_space)
A = create_power_operator(h_space, ampspec)
s_h = S.draw_sample()
sky = FFT * A
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n

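# Wiener filter quantities: full response R_p, information source
# j = R_p^dagger N^-1 d and curvature D^-1 = R_p^dagger N^-1 R_p + S^-1
# (the same operator that the WienerFilterCurvature below represents).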
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse


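# Number of posterior samples and maximum number of CG iterations.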
N_samps = 200
N_iter = 100

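# Accurate setup: tight CG tolerance for the posterior mean and the "long"
# samples.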
tol = 1e-3
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(
    S=S, N=N, R=R_p, inverter=inverter, sampling_inverter=inverter)
m_xi = curv.inverse_times(j)
samps_long = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]

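# Sloppy setup: same curvature, but the CG used for sampling stops almost
# immediately, yielding the "short" samples.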
tol = 1e2
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(
    S=S, N=N, R=R_p, inverter=inverter, sampling_inverter=inverter)
samps_short = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]

# Sample mean of the long samples; together with m_xi it gives the
# reconstruction in position space.
sc = ift.StatCalculator()
for samp in samps_long:
    sc.add(samp)
m_x = sky(sc.mean + m_xi)

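# Plot data, ground truth and reconstruction.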
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
plt.plot(s_x.to_global_data(), label="original")
plt.plot(m_x.to_global_data(), label="reconstruction")
plt.legend()
plt.savefig('reconstruction.png')
plt.close()

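# Overplot the residual samples from both setups together with the true
# residual s_x - m_x.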
pltdict = {'alpha': .3, 'linewidth': .2}
for i in range(N_samps):
    if i == 0:
        plt.plot(sky(samps_short[i]).to_global_data(), color='b',
                 label='Short samples (residuals)',
                 **pltdict)
        plt.plot(sky(samps_long[i]).to_global_data(), color='r',
                 label='Long samples (residuals)',
                 **pltdict)
    else:
        plt.plot(sky(samps_short[i]).to_global_data(), color='b', **pltdict)
        plt.plot(sky(samps_long[i]).to_global_data(), color='r', **pltdict)
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('samples_residuals.png')
plt.close()

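# Pointwise sample standard deviation as an uncertainty estimate for both
# setups.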
D_hat_short = ift.full(x_space, 0.).to_global_data()
D_hat_long = ift.full(x_space, 0.).to_global_data()
for i in range(N_samps):
    D_hat_short += sky(samps_short[i]).to_global_data()**2
    D_hat_long += sky(samps_long[i]).to_global_data()**2
plt.plot(np.sqrt(D_hat_short / N_samps), 'r--', label='Short uncertainty')
plt.plot(-np.sqrt(D_hat_short / N_samps), 'r--')
plt.fill_between(range(len(D_hat_long)), -np.sqrt(D_hat_long / N_samps),
                 np.sqrt(D_hat_long / N_samps), facecolor='0.5', alpha=0.5,
                 label='Long uncertainty')
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('uncertainty.png')
plt.close()