krylov_sampling.py 2.48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import nifty4 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty4.sugar import create_power_operator

x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()

# Data live on the same grid as the signal (the response R below is the identity).
d_space = x_space
# Diagonal noise covariance: 10 everywhere except pixels 400-449, which get
# 1e-4, i.e. a nearly noise-free stretch of data.
N_hat = ift.Field(d_space, 10.)
N_hat.val[400:450] = 0.0001
N = ift.DiagonalOperator(N_hat, d_space)

FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
R = ift.ScalingOperator(1., x_space)


def ampspec(k):
    """Return the spectrum 1 / (1 + k**2) handed to ``create_power_operator``.

    ``A`` built from it is applied to the harmonic excitation before the
    harmonic transform (see ``sky`` below), i.e. it acts as an amplitude.
    """
    return 1./(1.+k**2.)


# Prior on the harmonic excitation: identity operator on h_space
# (presumably unit-variance white noise for draw_sample -- confirm in nifty4).
S = ift.ScalingOperator(1., h_space)
A = create_power_operator(h_space, ampspec)
s_h = S.draw_sample()
# Signal model: excitation -> amplitude-weighted -> position space.
sky = FFT*A
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n

# Full response from the excitation to the data.
R_p = R*FFT*A
# Information source j = R_p^dagger N^-1 d and posterior precision
# (Wiener-filter curvature) D^-1 = R_p^dagger N^-1 R_p + S^-1.
j = R_p.adjoint(N.inverse(d))
D_inv = R_p.adjoint*N.inverse*R_p + S.inverse

def sample(D_inv, S, j, N_samps, N_iter):
    """Solve ``D_inv x = j`` by conjugate gradients and draw Krylov samples.

    The solution is approximated by at most ``N_iter`` conjugate-gradient
    (CG) steps started from zero.  Alongside, ``N_samps`` draws from ``S``
    are refined.  At every step, the component of each draw along the
    current search direction ``p`` (measured in the ``D_inv`` inner
    product) is removed.  It is replaced by a fresh standard-normal
    coefficient times ``p / sqrt(p^H D_inv p)``.  The intent is that, within
    the explored Krylov subspace, the draws follow the inverse of ``D_inv``.

    Parameters
    ----------
    D_inv : linear operator
        Operator to invert.  CG requires it to be self-adjoint and positive
        definite; this routine additionally uses self-adjointness to reuse
        ``D_inv(p)`` when projecting the samples.
    S : operator providing ``draw_sample()``
        Source of the initial draws.
    j : Field
        Right-hand side, living on ``D_inv.domain``.
    N_samps : int
        Number of samples to produce.
    N_iter : int
        Maximum number of CG iterations.

    Returns
    -------
    x : Field
        Approximate solution of ``D_inv x = j`` (zero if ``j`` is zero).
    y : list of Field
        The ``N_samps`` refined draws (fluctuations, not shifted by ``x``).
    """
    x = ift.Field.zeros(D_inv.domain)
    r = j.copy()
    p = r.copy()
    # D_inv(p) is the only operator application needed per iteration; it
    # feeds the curvature d, the residual update and the sample projections.
    Ap = D_inv(p)
    d = p.vdot(Ap)
    rr = r.vdot(r)
    y = [S.draw_sample() for _ in range(N_samps)]
    for _ in range(N_iter):
        # d == 0 means p == 0 for a positive-definite D_inv: either j was zero
        # or CG converged exactly.  Testing before dividing avoids 0/0 -> NaN.
        if d == 0.:
            break
        gamma = rr/d
        if gamma == 0.:
            break
        x += gamma*p
        for i in range(N_samps):
            # <p, D_inv y> == <D_inv p, y> for self-adjoint D_inv, so the cached
            # Ap saves one operator application per sample and iteration.
            y[i] -= Ap.vdot(y[i])*p/d
            y[i] += np.random.randn()/np.sqrt(d)*p
        r = r - gamma*Ap
        rr_new = r.vdot(r)
        beta = rr_new/rr
        rr = rr_new
        p = r + beta*p
        Ap = D_inv(p)
        d = p.vdot(Ap)
    return x, y

N_samps = 200
N_iter = 10
m, samps = sample(D_inv, S, j, N_samps, N_iter)
m_x = sky(m)

# Reference samples from NIFTy's Wiener-filter curvature, with its CG inverter
# limited to the same number of iterations (presumably draws from the inverse
# curvature, i.e. the posterior covariance -- confirm in nifty4 docs).
IC = ift.GradientNormController(iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = [curv.draw_sample(from_inverse=True) for _ in range(N_samps)]

# Map every sample to position space once; both sample plots reuse these.
samps_old_x = [sky(s).val for s in samps_old]
samps_x = [sky(s).val for s in samps]
residual = (s_x - m_x).val

# Figure 1: data, true signal and reconstruction.
plt.plot(d.val, 'o', label="data")
plt.plot(s_x.val, label="original")
plt.plot(m_x.val, label="reconstruction")
plt.legend()
plt.show()

# Figure 2: individual samples of both samplers against the actual error.
# Only the first curve of each family is labelled so the legend stays short.
for i, (old_x, new_x) in enumerate(zip(samps_old_x, samps_x)):
    plt.plot(old_x, color='b',
             label="WienerFilterCurvature samples" if i == 0 else "_nolegend_")
    plt.plot(new_x, color='r',
             label="Krylov samples" if i == 0 else "_nolegend_")
plt.plot(residual, color='k', label="original - reconstruction")
plt.legend()
plt.show()

# Figure 3: pointwise sample standard deviation (+/-) of both samplers.
D_hat_old = np.sum(np.square(samps_old_x), axis=0)
D_hat_new = np.sum(np.square(samps_x), axis=0)
std_old = np.sqrt(D_hat_old/N_samps)
std_new = np.sqrt(D_hat_new/N_samps)
plt.plot(std_old, color='b', label="WienerFilterCurvature +/- std")
plt.plot(std_new, color='r', label="Krylov +/- std")
plt.plot(-std_old, color='b')
plt.plot(-std_new, color='r')
plt.plot(residual, color='k', label="original - reconstruction")
plt.legend()
plt.show()