# wiener_filter_via_hamiltonian.py
# Authors (per version history): Martin Reinecke, Jakob Knollmueller
"""Wiener filter demo built on NIFTy4's ``WienerFilterEnergy``.

A mock signal is drawn from a power-law prior spectrum, observed through a
masking instrument with additive Gaussian noise, and reconstructed by
minimizing the Wiener filter energy with ``RelaxedNewton``.  The script writes
PNG plots of the mock signal, the reconstruction, the posterior mean, and two
estimates of the posterior variance (posterior sampling and diagonal probing).
"""

import nifty4 as ift
import numpy as np

# Seed NumPy's global RNG so repeated runs produce the same mock data.
np.random.seed(42)


if __name__ == "__main__":
    # Set up position space.
    # Alternatively: s_space = ift.RGSpace([128, 128]) -- the 2D branch of
    # the mask construction below handles that case.
    s_space = ift.HPSpace(32)

    # Define associated harmonic space and harmonic transformation
    h_space = s_space.get_default_codomain()
    ht = ift.HarmonicTransformOperator(h_space, s_space)

    # Set up power space
    p_space = ift.PowerSpace(h_space)

    # Choose prior correlation structure and define correlation operator
    def p_spec(k):
        """Prior power spectrum: 42 / (k + 1)**3."""
        return 42/(k+1)**3

    S = ift.create_power_operator(h_space, power_spectrum=p_spec)

    # Draw sample sh from the prior distribution in harmonic space
    sp = ift.PS_field(p_space, p_spec)
    sh = ift.power_synthesize(sp, real_signal=True)

    # Choose measurement instrument: a 0/1 mask over position space that
    # blanks out a contiguous pixel range (1D) or a square patch (2D).
    diag = np.ones(s_space.shape)
    if len(s_space.shape) == 1:
        diag[3000:7000] = 0
    elif len(s_space.shape) == 2:
        diag[20:80, 20:80] = 0
    else:
        raise NotImplementedError(
            "masking is only implemented for 1D and 2D position spaces")

    diag = ift.Field(s_space, ift.dobj.from_global_data(diag))
    Instrument = ift.DiagonalOperator(diag)

    # Add harmonic transformation to the instrument
    R = Instrument*ht
    noiseless_data = R(sh)

    # Scale the noise so its std equals the data std divided by the S/N.
    signal_to_noise = 1.
    noise_amplitude = noiseless_data.val.std()/signal_to_noise
    N = ift.ScalingOperator(noise_amplitude**2, s_space)
    n = ift.Field.from_random(domain=s_space,
                              random_type='normal',
                              std=noise_amplitude,
                              mean=0)

    # Create mock data
    d = noiseless_data + n

    # Choose minimization strategy
    ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
    inverter = ift.ConjugateGradient(controller=ctrl)
    controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
    minimizer = ift.RelaxedNewton(controller=controller)
    m0 = ift.Field.zeros(h_space)

    # Initialize Wiener filter energy and minimize it.
    energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
                                            inverter=inverter)
    energy, _ = minimizer(energy)
    m = energy.position
    # Curvature at the minimum; its inverse is probed below as the
    # posterior covariance in harmonic space.
    curv = energy.curvature

    # Generate plots. Transform each harmonic field to position space once.
    signal_map = ht(sh)
    reconstruction_map = ht(m)
    zmax = max(signal_map.max(), reconstruction_map.max())
    zmin = min(signal_map.min(), reconstruction_map.min())
    # Signal-like maps share one colour range so they are comparable;
    # variance maps get their own automatic range.
    plotdict = {"zmax": zmax, "zmin": zmin, "colormap": "Planck-like"}
    plotdict2 = {"colormap": "Planck-like"}
    ift.plot(signal_map, name="mock_signal.png", **plotdict)
    ift.plot(reconstruction_map, name="reconstruction.png", **plotdict)

    # Sample uncertainty map
    mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
    ift.plot(variance, name="posterior_variance.png", **plotdict2)
    # NOTE(review): presumably the samples are residuals about m, so their
    # mean is added back to the reconstruction -- confirm against NIFTy4.
    ift.plot(mean+reconstruction_map, name="posterior_mean.png", **plotdict)

    # Try to do the same with diagonal probing of the position-space
    # posterior covariance ht * curv^-1 * ht^dagger.
    variance = ift.probe_diagonal(ht*curv.inverse*ht.adjoint, 100)
    # Variance map: use the variance plot settings, not the signal's range.
    ift.plot(variance, name="posterior_variance2.png", **plotdict2)