log_normal_wiener_filter.py 3.83 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
import nifty4 as ift
Martin Reinecke's avatar
updates  
Martin Reinecke committed
2
import numpy as np
Theo Steininger's avatar
Theo Steininger committed
3

4
if __name__ == "__main__":
Martin Reinecke's avatar
updates  
Martin Reinecke committed
5
    np.random.seed(42)
Martin Reinecke's avatar
Martin Reinecke committed
6
    # Setting up parameters
7 8
    correlation_length = 1.     # Typical distance over which the field is correlated
    field_variance = 2.         # Variance of field in position space
Theo Steininger's avatar
Theo Steininger committed
9 10
    # Smoothing length handed to ift.ResponseOperator as `sigma` further down.
    # NOTE(review): L is only referenced by the commented-out RGSpace line in
    # the geometry section; confirm which unit applies on the HPSpace domain.
    response_sigma = 0.02       # Smoothing length of response (in same unit as L)
    # The noise variance is later set to mock_signal.var() / signal_to_noise.
    signal_to_noise = 100       # The signal to noise ratio
11
    # Re-seeding here supersedes the np.random.seed(42) call at the top of the
    # script: no random numbers are drawn between the two calls.
    np.random.seed(43)          # Fixing the random seed
Martin Reinecke's avatar
Martin Reinecke committed
12

13 14 15 16
    def power_spectrum(k):
        """Return the power spectrum evaluated at ``k``.

        The amplitude is ``4 * correlation_length * field_variance**2`` and
        the spectrum falls off as ``(1 + k * correlation_length) ** -4``.
        Both parameters are read from the enclosing script scope.
        """
        amplitude = 4 * correlation_length * field_variance**2
        falloff = (1 + k * correlation_length) ** 4
        return amplitude / falloff

Martin Reinecke's avatar
Martin Reinecke committed
17 18 19 20
    # Setting up the geometry
    # L and N_pixels describe the regular-grid (RGSpace) alternative that is
    # kept below as a comment; N_pixels is also used later to derive N10.
    L = 2.  # Total side-length of the domain
    N_pixels = 128  # Grid resolution (pixels per axis)
    # signal_space = ift.RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
Martin Reinecke's avatar
updates  
Martin Reinecke committed
21 22 23 24
    # Active position-space domain.
    # NOTE(review): 16 is presumably the HEALPix nside -- confirm.
    signal_space = ift.HPSpace(16)
    # Harmonic counterpart of signal_space, as chosen by the space itself.
    harmonic_space = signal_space.get_default_codomain()
    # Operator taking fields on harmonic_space to signal_space; its adjoint
    # (fft.adjoint_times) is used for the opposite direction further down.
    fft = ift.FFTOperator(harmonic_space, target=signal_space)
    # Domain on which the power spectrum field (mock_power) is defined.
    power_space = ift.PowerSpace(harmonic_space)
25

Martin Reinecke's avatar
Martin Reinecke committed
26 27 28
    # Creating the mock signal
    S = ift.create_power_operator(harmonic_space,
                                  power_spectrum=power_spectrum)
Martin Reinecke's avatar
Martin Reinecke committed
29
    mock_power = ift.PS_field(power_space, power_spectrum)
Martin Reinecke's avatar
adjust  
Martin Reinecke committed
30
    mock_signal = fft(ift.power_synthesize(mock_power, real_signal=True))
31 32

    # Setting up an exemplary response
33
    mask = ift.Field.ones(signal_space)
34
    # One tenth of the grid resolution, rounded down; only referenced by the
    # commented-out masking line below.
    N10 = N_pixels // 10
Martin Reinecke's avatar
Martin Reinecke committed
35 36 37
    # With the masking line below disabled, `mask` is still the all-ones field
    # created above, so every pixel gets the same exposure value of 1.
    # mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
    # Response acting on signal_space, configured with the smoothing length
    # response_sigma and the exposure field `mask` (both passed as 1-tuples).
    R = ift.ResponseOperator(signal_space, sigma=(response_sigma,),
                             exposure=(mask,))
38
    data_domain = R.target[0]
Martin Reinecke's avatar
Martin Reinecke committed
39
    R_harmonic = R*fft
40 41

    # Setting up the noise covariance and drawing a random noise realization
42
    ndiag = ift.Field.full(data_domain, mock_signal.var()/signal_to_noise)
43
    N = ift.DiagonalOperator(ndiag.weight(1))
Martin Reinecke's avatar
updates  
Martin Reinecke committed
44
    noise = ift.Field.from_random(domain=data_domain, random_type='normal',
Martin Reinecke's avatar
Martin Reinecke committed
45 46
        std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0)
    data = R(ift.exp(mock_signal)) + noise
47 48

    # Wiener filter
49
    m0 = ift.Field.zeros(harmonic_space)
50
    ctrl = ift.GradientNormController(verbose=False, tol_abs_gradnorm=0.0001)
Martin Reinecke's avatar
Martin Reinecke committed
51 52
    ctrl2 = ift.GradientNormController(verbose=True, tol_abs_gradnorm=0.1,
                                       name="outer")
53
    inverter = ift.ConjugateGradient(controller=ctrl)
Martin Reinecke's avatar
Martin Reinecke committed
54 55 56
    # Energy object for the log-normal Wiener filter, built from the field m0
    # (zeros on harmonic_space), the data, the harmonic-space response
    # R_harmonic and the operators N and S.  The conjugate-gradient `inverter`
    # is handed over by keyword.
    # NOTE(review): presumably m0 is the starting position and the inverter is
    # used for curvature.inverse_times in the probing step -- confirm.
    energy = ift.library.LogNormalWienerFilterEnergy(m0, data, R_harmonic,
                                                     N, S, inverter=inverter)
    # minimizer1 = ift.VL_BFGS(controller=ctrl2, max_history_length=20)
Martin Reinecke's avatar
updates  
Martin Reinecke committed
57
    minimizer2 = ift.RelaxedNewton(controller=ctrl2)
Martin Reinecke's avatar
Martin Reinecke committed
58
    # minimizer3 = ift.SteepestDescent(controller=ctrl2)
Martin Reinecke's avatar
updates  
Martin Reinecke committed
59

Martin Reinecke's avatar
Martin Reinecke committed
60
    # me1 = minimizer1(energy)
Martin Reinecke's avatar
updates  
Martin Reinecke committed
61
    me2 = minimizer2(energy)
Martin Reinecke's avatar
Martin Reinecke committed
62
    # me3 = minimizer3(energy)
Theo Steininger's avatar
Theo Steininger committed
63

Martin Reinecke's avatar
Martin Reinecke committed
64
    # m1 = fft(me1[0].position)
Martin Reinecke's avatar
updates  
Martin Reinecke committed
65
    m2 = fft(me2[0].position)
Martin Reinecke's avatar
Martin Reinecke committed
66
    # m3 = fft(me3[0].position)
67

Martin Reinecke's avatar
Martin Reinecke committed
68
    # Plotting
Martin Reinecke's avatar
Martin Reinecke committed
69 70 71
    plotdict = {"xlabel": "Pixel index", "ylabel": "Pixel index",
                "colormap": "Planck-like"}
    ift.plotting.plot(mock_signal, name="mock_signal.png", **plotdict)
72
    # Take the logarithm of the full (global) data array and shape it like
    # signal_space so it can be wrapped in a Field for plotting.
    # NOTE(review): data = R(exp(signal)) + noise may contain non-positive
    # entries, for which np.log yields nan/-inf -- confirm this is acceptable.
    global_data = ift.dobj.to_global_data(data.val)
    logdata = np.log(global_data).reshape(signal_space.shape)
Martin Reinecke's avatar
Martin Reinecke committed
73 74
    ift.plotting.plot(ift.Field(signal_space,
                                val=ift.dobj.from_global_data(logdata)),
Martin Reinecke's avatar
Martin Reinecke committed
75 76 77 78
                      name="log_of_data.png", **plotdict)
    # ift.plotting.plot(m1,name='m_LBFGS.png', **plotdict)
    ift.plotting.plot(m2, name='m_Newton.png', **plotdict)
    # ift.plotting.plot(m3, name='m_SteepestDescent.png', **plotdict)
Theo Steininger's avatar
Theo Steininger committed
79

Martin Reinecke's avatar
Martin Reinecke committed
80
    # Probing the variance
Martin Reinecke's avatar
Martin Reinecke committed
81 82
    class Proby(ift.DiagonalProberMixin, ift.Prober):
        """Prober combining ``ift.Prober`` with ``ift.DiagonalProberMixin``.

        No behaviour is added here; the class only exists to mix the two
        bases.  The script reads the result from its ``diagonal`` attribute.
        """
Martin Reinecke's avatar
Martin Reinecke committed
83 84
    def _apply_inverse_curvature(z):
        """Apply the inverse curvature of the minimizer result to ``z``.

        ``z`` is moved to harmonic space with ``fft.adjoint_times``, the
        inverse curvature of ``me2[0]`` is applied, and the result is brought
        back with ``fft``.
        """
        harmonic_probe = fft.adjoint_times(z)
        return fft(me2[0].curvature.inverse_times(harmonic_probe))

    # Probe the operator above on signal_space with a single probe.
    proby = Proby(signal_space, probe_count=1)
    proby(_apply_inverse_curvature)
85

Martin Reinecke's avatar
Martin Reinecke committed
86 87
    # Smooth the probed diagonal (after weight(-1)) before plotting it.
    # NOTE(review): sigma is hard-coded to 0.02, the same value as
    # response_sigma -- confirm whether the two are meant to be tied together.
    sm = ift.FFTSmoothingOperator(signal_space, sigma=0.02)
    raw_diagonal = proby.diagonal.weight(-1)
    variance = sm(raw_diagonal)
Martin Reinecke's avatar
Martin Reinecke committed
88
    ift.plotting.plot(variance, name='variance.png', **plotdict)