poisson_demo.py 3.3 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import nifty4 as ift
import matplotlib.pyplot as plt


class Exp3(object):
    """Nonlinearity f(x) = exp(3*x), bundled with its first derivative.

    Instances are passed to ``ift.library.PoissonEnergy`` as the nonlinear
    map from the latent field to the Poisson rate.
    """

    def __call__(self, x):
        """Apply exp(3*x) using ``ift.exp``."""
        scaled = 3*x
        return ift.exp(scaled)

    def derivative(self, x):
        """Return df/dx = 3*exp(3*x), reusing ``__call__`` for the exponential."""
        return 3*self(x)


if __name__ == "__main__":
    # Poisson log-normal demo: draw a Gaussian mock field phi, map it through
    # the nonlinearity exp(3*phi), draw Poisson counts through a masked
    # response, reconstruct the field with PoissonEnergy + RelaxedNewton,
    # estimate a posterior standard deviation by probing, and plot the result.
    np.random.seed(20)

    # Set up physical constants
    nu = 1.        # excitation field level
    kappa = 10.    # diffusion constant
    eps = 1e-8     # small number to tame zero mode
    L = 1.         # Total length of interval or volume the field lives on
    nprobes = 100  # Number of probes for uncertainty quantification

    # Define resolution (pixels per dimension)
    N_pixels = 1024

    # Define data gaps (pixel index ranges [N1a, N1b) and [N2a, N2b) are
    # masked out of the response)
    N1a = int(0.6*N_pixels)
    N1b = int(0.64*N_pixels)
    N2a = int(0.67*N_pixels)
    N2b = int(0.8*N_pixels)

    # Set up derived constants
    amp = nu/(2*kappa)  # spectral normalization

    def pow_spec(k):
        # eps keeps the spectrum finite at k = 0
        return amp / (eps + k**2)

    pixel_width = L/N_pixels
    # Pixel positions, built from N_pixels so the length always matches the
    # field and the coordinates scale with L.
    x = np.arange(N_pixels)*pixel_width

    # Set up the geometry
    s_domain = ift.RGSpace([N_pixels], distances=pixel_width)
    h_domain = s_domain.get_default_codomain()
    HT = ift.HarmonicTransformOperator(h_domain, s_domain)

    # Create mock signal
    Phi_h = ift.create_power_operator(h_domain, power_spectrum=pow_spec)
    phi_h = Phi_h.draw_sample()
    # remove zero mode (stored at index 0 of the harmonic-space sample)
    phi_h.val[0] = 0
    phi = HT(phi_h)  # mock signal in position space; reused below

    # Setting up an exemplary response: strip the geometry, then zero out
    # the gap pixels
    GeoRem = ift.GeometryRemover(s_domain)
    d_domain = GeoRem.target[0]
    mask = np.ones(d_domain.shape)
    mask[N1a:N1b] = 0.
    mask[N2a:N2b] = 0.
    mask = ift.Field.from_global_data(d_domain, mask)
    Mask = ift.DiagonalOperator(mask)
    R0 = Mask*GeoRem
    IC = ift.GradientNormController(name="inverter", iteration_limit=500,
                                    tol_abs_gradnorm=1e-3)
    inverter = ift.ConjugateGradient(controller=IC)

    # Expected counts: masked response applied to exp(3*phi)
    nonlin = Exp3()
    lam = R0(nonlin(phi))
    # Poisson draws are integers; store them as float64 data
    counts = np.random.poisson(lam.to_global_data()).astype(np.float64)
    data = ift.Field.from_global_data(d_domain, counts)

    # Reconstruction, starting from a (nearly) zero harmonic-space field
    psi0 = ift.Field.full(h_domain, 1e-7)
    energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
                                       inverter)
    IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
                                     tol_abs_gradnorm=1e-4)
    minimizer = ift.RelaxedNewton(IC1)
    energy = minimizer(energy)[0]

    # Posterior variance of HT(position), estimated from nprobes samples of
    # the curvature
    var = ift.probe_with_posterior_samples(energy.curvature, HT, nprobes)[1]
    sig = ift.sqrt(var)

    m = HT(energy.position)

    # Plot positions for the data: masked pixels are moved to x = 2*L, which
    # lies outside the plotted range [0, L], so only observed counts appear.
    mask_arr = mask.to_global_data()
    x_mod = np.where(mask_arr > 0, x, 2*L)

    plt.rcParams["text.usetex"] = True
    c1 = nonlin(m-sig).to_global_data()
    c2 = nonlin(m+sig).to_global_data()
    plt.fill_between(x, c1, c2, color='pink', alpha=None)
    plt.plot(x, nonlin(phi).to_global_data(), label=r"NL($\varphi$)",
             color='black')
    plt.scatter(x_mod, data.to_global_data(), label=r'$d$', s=1,
                color='blue', alpha=0.5)
    # Label math must be balanced ($...$): usetex rendering fails otherwise
    plt.plot(x, nonlin(m).to_global_data(), label=r'NL($m$)', color='red')
    plt.xlim([0, L])
    plt.ylim([-0.1, None])
    plt.title('Poisson log-normal reconstruction')
    plt.legend()
    plt.savefig('Poisson.png')