wiener_filter_via_hamiltonian.py 2.75 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
import nifty5 as ift
Martin Reinecke's avatar
Martin Reinecke committed
2
import numpy as np
3

4
# Seed NumPy's global random number generator with a fixed value so that
# repeated runs of this script start from the same RNG state.
np.random.seed(42)
5

Martin Reinecke's avatar
Martin Reinecke committed
6

7
if __name__ == "__main__":
    # Demo script: draw a mock signal from a prior, observe it through a
    # masked instrument with additive noise, reconstruct it by minimizing a
    # WienerFilterEnergy, and plot the result together with two estimates of
    # the posterior variance.

    # Set up position space.  The mask construction below also has a 2-D
    # branch, so a regular grid can be swapped in here instead:
    # s_space = ift.RGSpace([128, 128])
    s_space = ift.HPSpace(32)

    # Define associated harmonic space and harmonic transformation
    h_space = s_space.get_default_codomain()
    ht = ift.HarmonicTransformOperator(h_space, s_space)

    # Choose prior correlation structure and define correlation operator
    def p_spec(k):
        """Return the prior power spectrum evaluated at mode(s) `k`."""
        return 42/(k+1)**3

    S = ift.create_power_operator(h_space, power_spectrum=p_spec)

    # Draw sample sh from the prior distribution in harmonic space
    sh = S.draw_sample()

    # Choose measurement instrument: a diagonal mask that is 1 everywhere
    # except for a zeroed-out region.
    diag = np.ones(s_space.shape)
    n_dims = len(s_space.shape)
    if n_dims == 1:
        diag[3000:7000] = 0
    elif n_dims == 2:
        diag[20:80, 20:80] = 0
    else:
        raise NotImplementedError
    diag = ift.Field.from_global_data(s_space, diag)
    Instrument = ift.DiagonalOperator(diag)

    # Add harmonic transformation to the instrument
    R = Instrument*ht
    noiseless_data = R(sh)

    # Scale the noise relative to the spread of the noiseless data.
    signal_to_noise = 1.
    noise_amplitude = noiseless_data.val.std()/signal_to_noise
    N = ift.ScalingOperator(noise_amplitude**2, s_space)
    n = N.draw_sample()

    # Create mock data
    d = noiseless_data + n

    # Choose minimization strategy
    ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
    inverter = ift.ConjugateGradient(controller=ctrl)
    controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
    minimizer = ift.RelaxedNewton(controller=controller)
    m0 = ift.full(h_space, 0.)

    # Initialize Wiener filter energy; the same inverter is used both for
    # the energy itself and for sampling.
    energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
                                            inverter=inverter,
                                            sampling_inverter=inverter)

    # Minimize; the second return value (convergence status) is not used.
    energy, _ = minimizer(energy)
    m = energy.position
    curv = energy.curvature

    # Generate plots.  Transform signal and reconstruction to position space
    # once and reuse the results for both the color range and the plots.
    signal_map = ht(sh)
    reconstruction_map = ht(m)
    zmax = max(signal_map.max(), reconstruction_map.max())
    zmin = min(signal_map.min(), reconstruction_map.min())
    # plotdict: shared color range, for maps that live on the signal scale.
    plotdict = {"zmax": zmax, "zmin": zmin, "colormap": "Planck-like"}
    # plotdict2: no fixed range, for variance maps.
    plotdict2 = {"colormap": "Planck-like"}
    ift.plot(signal_map, name="mock_signal.png", **plotdict)
    ift.plot(reconstruction_map, name="reconstruction.png", **plotdict)

    # Sample uncertainty map
    mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
    ift.plot(variance, name="posterior_variance.png", **plotdict2)
    ift.plot(mean + reconstruction_map, name="posterior_mean.png", **plotdict)

    # try to do the same with diagonal probing
    variance = ift.probe_diagonal(ht*curv.inverse*ht.adjoint, 100)
    # Plot with plotdict2, like the first variance map: the zmin/zmax stored
    # in plotdict were derived from the signal, not from the variance.
    ift.plot(variance, name="posterior_variance2.png", **plotdict2)