"""Wiener filter demo: reconstruct a masked, noisy mock signal by minimizing
the Wiener filter energy (Hamiltonian) with NIFTy5."""

import nifty5 as ift
import numpy as np

# Seed NumPy's global RNG so runs are reproducible.
# NOTE(review): assumes nifty5's draw_sample / probing routines draw from
# NumPy's global RNG -- confirm against the installed nifty5 version.
np.random.seed(42)
if __name__ == "__main__":
    # Set up position space. Swap in the commented RGSpace line to run the
    # demo on a 2-D regular grid instead of the HPSpace default.
    # s_space = ift.RGSpace([128, 128])
    s_space = ift.HPSpace(32)

    # Define associated harmonic space and harmonic transformation
    h_space = s_space.get_default_codomain()
    ht = ift.HarmonicTransformOperator(h_space, s_space)

    # Choose prior correlation structure and define correlation operator
    def p_spec(k):
        """Prior power spectrum 42 / (k + 1)**3 evaluated at harmonic k."""
        return 42/(k+1)**3

    S = ift.create_power_operator(h_space, power_spectrum=p_spec)

    # Draw sample sh from the prior distribution in harmonic space
    sh = S.draw_sample()

    # Choose measurement instrument: a 0/1 mask that blanks out a block of
    # pixels. The index ranges are hard-coded for the two space choices
    # above (1-D for HPSpace, 2-D for RGSpace); other ranks are unsupported.
    mask = np.ones(s_space.shape)
    if len(s_space.shape) == 1:
        mask[3000:7000] = 0
    elif len(s_space.shape) == 2:
        mask[20:80, 20:80] = 0
    else:
        raise NotImplementedError

    mask = ift.Field.from_global_data(s_space, mask)
    instrument = ift.DiagonalOperator(mask)

    # Add harmonic transformation to the instrument
    R = instrument*ht
    noiseless_data = R(sh)

    # Scale the noise so its standard deviation relative to the noiseless
    # data's standard deviation matches the requested signal-to-noise ratio.
    signal_to_noise = 1.
    noise_amplitude = noiseless_data.val.std()/signal_to_noise
    N = ift.ScalingOperator(noise_amplitude**2, s_space)
    n = N.draw_sample()

    # Create mock data
    d = noiseless_data + n

    # Choose minimization strategy: one controller for the inner inversions
    # (also used for sampling), one for the outer Newton minimization.
    inverter_controller = ift.GradientNormController(
        name="inverter", tol_abs_gradnorm=0.1)
    minimizer_controller = ift.GradientNormController(
        name="min", tol_abs_gradnorm=0.1)
    minimizer = ift.RelaxedNewton(controller=minimizer_controller)
    m0 = ift.full(h_space, 0.)

    # Initialize Wiener filter energy
    energy = ift.library.WienerFilterEnergy(
        position=m0, d=d, R=R, N=N, S=S,
        iteration_controller=inverter_controller,
        iteration_controller_sampling=inverter_controller)

    # The minimizer's convergence flag is not used by this demo.
    energy, _ = minimizer(energy)
    m = energy.position
    curv = energy.curvature

    # Generate plots. Transform to position space once and reuse the result;
    # signal and reconstruction share one colour scale for comparison.
    mock_signal = ht(sh)
    reconstruction = ht(m)
    zmax = max(mock_signal.max(), reconstruction.max())
    zmin = min(mock_signal.min(), reconstruction.min())
    signal_plot_kwargs = {
        "zmax": zmax, "zmin": zmin, "colormap": "Planck-like"}
    variance_plot_kwargs = {"colormap": "Planck-like"}
    ift.plot(mock_signal, name="mock_signal.png", **signal_plot_kwargs)
    ift.plot(reconstruction, name="reconstruction.png", **signal_plot_kwargs)

    # Sample uncertainty map
    n_posterior_samples = 50
    mean, variance = ift.probe_with_posterior_samples(
        curv, ht, n_posterior_samples)
    ift.plot(variance, name="posterior_variance.png", **variance_plot_kwargs)
    ift.plot(mean + reconstruction, name="posterior_mean.png",
             **signal_plot_kwargs)

    # Estimate the same variance map by probing the diagonal of
    # ht * curv.inverse * ht.adjoint. Plot it with the same settings as the
    # sample-based variance (not the signal's zmin/zmax) so both are
    # comparable.
    n_probes = 100
    variance = ift.probe_diagonal(ht*curv.inverse*ht.adjoint, n_probes)
    ift.plot(variance, name="posterior_variance2.png", **variance_plot_kwargs)