"""Wiener filter reconstruction on a two-dimensional Cartesian product domain.

A mock signal is drawn from a prior whose harmonic-space covariance is the
product of two one-dimensional power spectra (one per axis). It is observed
through a response that smooths it in harmonic space, transforms it to
position space and masks a band along each axis. Gaussian noise is added and
the Wiener filter map is obtained by inverting the Wiener filter curvature
with conjugate gradients. The mock signal, the data, the reconstruction, a
sampled uncertainty map and the sampled posterior mean are written as PNG
files to the working directory.
"""

import numpy as np
import nifty4 as ift


def make_power_spectrum(correlation_length, field_variance, exponent):
    """Return the spectrum ``k -> a / (1 + k * correlation_length)**exponent``.

    The amplitude is ``a = 4 * correlation_length * field_variance**2``,
    i.e. ``field_variance**2 = a * k_0 / 4`` with
    ``k_0 = 1 / correlation_length``.

    Parameters
    ----------
    correlation_length : float
        Length scale entering the turnover term ``1 + k * correlation_length``.
    field_variance : float
        Amplitude parameter; enters ``a`` squared.
    exponent : float
        Slope of the power-law tail at large ``k``.

    Returns
    -------
    callable
        Function of ``k`` suitable for ``ift.create_power_operator``.
    """
    amplitude = 4 * correlation_length * field_variance**2

    def power_spectrum(k):
        return amplitude / (1 + k * correlation_length) ** exponent

    return power_spectrum


def make_band_mask(space, start_tenth=7, stop_tenth=9):
    """Return a Field on the 1D ``space`` that is 1 except for a zeroed band.

    With ``tenth = n // 10`` (``n`` = number of pixels of ``space``), the
    pixels ``[start_tenth * tenth, stop_tenth * tenth)`` are set to 0. The
    defaults mask roughly the band between 70% and 90% of the axis (exactly
    so up to the integer rounding of ``tenth``).
    """
    tenth = space.shape[0] // 10
    mask = np.ones(space.shape)
    mask[start_tenth * tenth:stop_tenth * tenth] = 0.
    return ift.Field.from_global_data(space, mask)


def main():
    """Build the mock problem, run the Wiener filter and write the plots."""
    signal_to_noise = 0.5  # The signal to noise ratio

    # Setting up parameters
    L_1 = 2.                   # Total side-length of the domain (axis 1)
    N_pixels_1 = 512           # Grid resolution (pixels along axis 1)
    L_2 = 2.                   # Total side-length of the domain (axis 2)
    N_pixels_2 = 512           # Grid resolution (pixels along axis 2)
    correlation_length_1 = 1.
    field_variance_1 = 2.      # Variance of field in position space
    response_sigma_1 = 0.05    # Smoothing length of response
    correlation_length_2 = 1.
    field_variance_2 = 2.      # Variance of field in position space
    response_sigma_2 = 0.01    # Smoothing length of response

    # One prior spectrum per axis; they differ only in their slope.
    power_spectrum_1 = make_power_spectrum(
        correlation_length_1, field_variance_1, 4.)
    power_spectrum_2 = make_power_spectrum(
        correlation_length_2, field_variance_2, 2.5)

    signal_space_1 = ift.RGSpace([N_pixels_1], distances=L_1/N_pixels_1)
    harmonic_space_1 = signal_space_1.get_default_codomain()
    signal_space_2 = ift.RGSpace([N_pixels_2], distances=L_2/N_pixels_2)
    harmonic_space_2 = signal_space_2.get_default_codomain()

    signal_domain = ift.DomainTuple.make((signal_space_1, signal_space_2))
    harmonic_domain = ift.DomainTuple.make((harmonic_space_1,
                                            harmonic_space_2))

    # Harmonic -> position transform, applied one sub-space at a time.
    ht_1 = ift.HarmonicTransformOperator(harmonic_domain, space=0)
    ht_2 = ift.HarmonicTransformOperator(ht_1.target, space=1)
    ht = ht_2*ht_1

    # Prior covariance in harmonic space: product of the per-axis spectra.
    S = (ift.create_power_operator(harmonic_domain, power_spectrum_1, 0) *
         ift.create_power_operator(harmonic_domain, power_spectrum_2, 1))

    # Fixed seed so the mock signal and the noise draw are reproducible.
    np.random.seed(10)
    mock_signal = S.draw_sample()

    # Setting up an exemplary response. Operators apply right to left:
    # smooth each axis in harmonic space, transform to position space,
    # mask a band along each axis, then strip the geometry.
    mask_1 = make_band_mask(signal_space_1)
    mask_2 = make_band_mask(signal_space_2)

    R = ift.GeometryRemover(signal_domain)
    R = R*ift.DiagonalOperator(mask_1, signal_domain, spaces=0)
    R = R*ift.DiagonalOperator(mask_2, signal_domain, spaces=1)
    R = R*ht
    R = R * ift.create_harmonic_smoothing_operator(harmonic_domain, 0,
                                                   response_sigma_1)
    R = R * ift.create_harmonic_smoothing_operator(harmonic_domain, 1,
                                                   response_sigma_2)
    data_domain = R.target

    noiseless_data = R(mock_signal)
    # Choose the noise level so that std(noiseless_data) / noise_amplitude
    # equals signal_to_noise.
    noise_amplitude = noiseless_data.val.std()/signal_to_noise
    # Setting up the noise covariance and drawing a random noise realization
    N = ift.ScalingOperator(noise_amplitude**2, data_domain)
    noise = N.draw_sample()
    data = noiseless_data + noise

    # Wiener filter: m_k = curvature^{-1} j, with j = R^dagger N^{-1} data.
    j = R.adjoint_times(N.inverse_times(data))
    ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
    inverter = ift.ConjugateGradient(controller=ctrl)
    wiener_curvature = ift.library.WienerFilterCurvature(
        S=S, N=N, R=R, inverter=inverter)

    m_k = wiener_curvature.inverse_times(j)
    m = ht(m_k)

    # Fields live on a product of two 1D spaces; recast them onto a single
    # 2D RGSpace of the same shape so they can be plotted as images.
    plotdict = {"colormap": "Planck-like"}
    plot_space = ift.RGSpace((N_pixels_1, N_pixels_2))
    ift.plot(ht(mock_signal).cast_domain(plot_space),
             name='mock_signal.png', **plotdict)
    ift.plot(data.cast_domain(plot_space), name='data.png', **plotdict)
    ift.plot(m.cast_domain(plot_space), name='map.png', **plotdict)

    # Sampling the uncertainty map from 10 posterior samples.
    mean, variance = ift.probe_with_posterior_samples(wiener_curvature, ht, 10)
    ift.plot(ift.sqrt(variance).cast_domain(plot_space),
             name="uncertainty.png", **plotdict)
    # NOTE(review): adding m assumes the probed samples are fluctuations
    # around zero rather than around the map itself -- confirm in nifty4.
    ift.plot((mean+m).cast_domain(plot_space),
             name="posterior_mean.png", **plotdict)


if __name__ == "__main__":
    main()