log_normal_wiener_filter.py 3.41 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
import nifty4 as ift
Martin Reinecke's avatar
updates  
Martin Reinecke committed
2
import numpy as np
Theo Steininger's avatar
Theo Steininger committed
3

4
if __name__ == "__main__":
    # Setting up parameters
    correlation_length = 1.     # Typical distance over which the field is correlated
    field_variance = 2.         # Variance of field in position space
    response_sigma = 0.02       # Smoothing length of the response
    signal_to_noise = 100       # The signal to noise ratio

    # Fixing the random seed. It is set exactly once, before any random draw,
    # so the run is reproducible. 43 is the seed that actually governed all
    # draws; an earlier seed(42) call was immediately overwritten by it.
    np.random.seed(43)

    def power_spectrum(k):      # Defining the power spectrum
        """Return the mock signal's power spectrum at k.

        P(k) = 4 * l * v**2 / (1 + k * l)**4, where l is `correlation_length`
        and v is `field_variance`, both read from the enclosing script.
        """
        norm = 4 * correlation_length * field_variance**2
        falloff = (1 + k * correlation_length) ** 4
        return norm / falloff

    # Setting up the geometry.
    # The signal lives on a HEALPix sphere. Its default harmonic codomain
    # carries the prior and the response smoothing. A flat 2-D alternative
    # (used together with the mask cut-out shown further below) would be:
    #     L = 2.          # Total side-length of the domain
    #     N_pixels = 128  # Grid resolution (pixels per axis)
    #     signal_space = ift.RGSpace([N_pixels, N_pixels],
    #                                distances=L/N_pixels)
    signal_space = ift.HPSpace(16)
    harmonic_space = signal_space.get_default_codomain()
    # Maps fields on harmonic_space to signal_space.
    HT = ift.HarmonicTransformOperator(harmonic_space, target=signal_space)
    power_space = ift.PowerSpace(harmonic_space)

    # Creating the mock signal: the prior covariance S and one realization
    # drawn from the same power spectrum (the realization lives on
    # harmonic_space; it is transformed with HT before plotting).
    S = ift.create_power_operator(harmonic_space,
                                  power_spectrum=power_spectrum)
    mock_power = ift.PS_field(power_space, power_spectrum)
    mock_signal = ift.power_synthesize(mock_power, real_signal=True)

    # Setting up an exemplary response
    #     R = GeometryRemover * mask * HT * harmonic smoothing,
    # applied right to left: smooth in harmonic space, transform to
    # signal_space, apply the mask, then drop the geometry.
    # The mask is all ones, so it is currently a no-op. For the RGSpace
    # alternative above, a square could be cut out with e.g.
    #     N10 = int(N_pixels/10)
    #     mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
    mask = ift.Field.ones(signal_space)
    R = ift.GeometryRemover(signal_space)
    R = R * ift.DiagonalOperator(mask)
    R = R * HT
    R = R * ift.create_harmonic_smoothing_operator((harmonic_space,), 0,
                                                    response_sigma)
    data_domain = R.target[0]

    # Setting up the noise covariance and drawing a random noise realization.
    # The noise standard deviation is the standard deviation of the noiseless
    # data divided by signal_to_noise. N is a ScalingOperator on the data
    # domain with factor sigma_n**2.
    # NOTE(review): the mock data are linear in mock_signal (R applied to the
    # signal, plus noise), while the reconstruction uses
    # LogNormalWienerFilterEnergy. Confirm this mismatch is intended.
    noiseless_data = R(mock_signal)
    sigma_n = noiseless_data.val.std() / signal_to_noise
    N = ift.ScalingOperator(sigma_n**2, data_domain)
    noise = ift.Field.from_random(domain=data_domain, random_type='normal',
                                  std=sigma_n, mean=0)
    data = noiseless_data + noise

    # Wiener filter: minimise the log-normal Wiener filter energy, starting
    # from the zero field on harmonic_space.
    m0 = ift.Field.zeros(harmonic_space)

    # Two stopping criteria. The inner one drives the conjugate-gradient
    # solver handed to the energy as `inverter` (presumably used to invert the
    # curvature). The outer one drives the minimizer itself.
    inner_ctrl = ift.GradientNormController(tol_abs_gradnorm=0.0001)
    outer_ctrl = ift.GradientNormController(tol_abs_gradnorm=0.1, name="outer")
    inverter = ift.ConjugateGradient(controller=inner_ctrl)
    energy = ift.library.LogNormalWienerFilterEnergy(
        m0, data, R, N, S, inverter=inverter)

    # Alternative minimizers that accept the same controller:
    #     ift.VL_BFGS(controller=outer_ctrl, max_history_length=20)
    #     ift.SteepestDescent(controller=outer_ctrl)
    minimizer = ift.RelaxedNewton(controller=outer_ctrl)

    # me[0] is the final energy. Its position is the reconstruction in
    # harmonic space, which is transformed to signal_space for plotting.
    me = minimizer(energy)
    m = HT(me[0].position)

    # Plotting
    plotdict = {"colormap": "Planck-like"}
    ift.plot(HT(mock_signal), name="mock_signal.png", **plotdict)

    # Take the logarithm of the data and map it back onto signal_space.
    # NOTE(review): the data include zero-mean Gaussian noise, so they can
    # contain non-positive values. For those, np.log yields nan/-inf and emits
    # a RuntimeWarning. Confirm this plot is meaningful.
    global_data = ift.dobj.to_global_data(data.val)
    logdata = np.log(global_data).reshape(signal_space.shape)
    log_field = ift.Field(signal_space,
                          val=ift.dobj.from_global_data(logdata))
    ift.plot(log_field, name="log_of_data.png", **plotdict)

    ift.plot(m, name='m.png', **plotdict)

    # Probing the variance.
    # Estimate the diagonal of the inverse curvature of the final energy (an
    # approximation of the posterior covariance), expressed in signal_space.
    # Each probe z is pulled back with HT.adjoint, multiplied by the inverse
    # curvature, and transformed forward with HT.
    class Proby(ift.DiagonalProberMixin, ift.Prober):
        pass

    proby = Proby(signal_space, probe_count=1)
    # The minimization result is `me`; the name `me2` used here before was
    # never defined and raised a NameError.
    final_energy = me[0]
    proby(lambda z: HT(final_energy.curvature.inverse_times(
        HT.adjoint_times(z))))

    # Smooth the single-probe (hence noisy) diagonal estimate before plotting.
    # NOTE(review): FFTSmoothingOperator is FFT-based, but signal_space is an
    # HPSpace here. Confirm it accepts this domain; it may require an RGSpace.
    sm = ift.FFTSmoothingOperator(signal_space, sigma=0.02)
    variance = sm(proby.diagonal.weight(-1))
    ift.plot(variance, name='variance.png', **plotdict)