"""Critical filtering demo for NIFTy4.

Draws a mock signal from a known prior power spectrum, simulates noisy data,
and then jointly reconstructs the signal map and its log-power spectrum by
alternating between a Wiener-filter map estimate and a minimization of
NIFTy's ``CriticalPowerEnergy``. The data are written to ``data.png`` and the
current map estimate is periodically written to ``map.png``.
"""
import numpy as np
import nifty4 as ift

# Fixed seed so the mock data and the reconstruction are reproducible.
np.random.seed(42)
# np.seterr(all="raise",under="ignore")


if __name__ == "__main__":
    # Set up position space
    s_space = ift.RGSpace([128, 128])
    # s_space = ift.HPSpace(32)

    # Define harmonic transformation and associated harmonic space
    h_space = s_space.get_default_codomain()
    fft = ift.FFTOperator(h_space, s_space)

    # Set up power space with logarithmically spaced bins
    p_space = ift.PowerSpace(h_space,
                             binbounds=ift.PowerSpace.useful_binbounds(
                                 h_space, logarithmic=True))

    # Prior power spectrum used to generate the mock signal
    def p_spec(k):
        return .5 / (k + 1) ** 3

    # Draw a sample sh from the prior distribution in harmonic space
    sp = ift.PS_field(p_space, p_spec)
    sh = ift.power_synthesize(sp, real_signal=True)

    # Choose the measurement instrument (a unit diagonal, i.e. identity)
    # Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = ift.DiagonalOperator(ift.Field.full(s_space, 1.))
    # Optional masking by zeroing part of the diagonal.
    # NOTE(review): these slice bounds look like they were written for a
    # 512x512 grid; adapt them to the 128x128 s_space before enabling.
    # Instrument._diagonal.val[200:400, 200:400] = 0
    # Instrument._diagonal.val[64:512-64, 64:512-64] = 0

    # Add a harmonic transformation to the instrument
    R = Instrument*fft

    # Noise covariance N and a Gaussian noise realization n with
    # standard deviation sqrt(noise)
    noise = 1.
    N = ift.DiagonalOperator(ift.Field.full(s_space, noise).weight(1))
    n = ift.Field.from_random(domain=s_space, random_type='normal',
                              std=np.sqrt(noise), mean=0)

    # Create mock data
    d = R(sh) + n

    # The information source j = R^dagger N^-1 d
    j = R.adjoint_times(N.inverse_times(d))
    ift.plotting.plot(d, name="data.png")

    # Newton-type minimizer for the power-spectrum energy (outer problem)
    ic_newton = ift.GradientNormController(verbose=True, iteration_limit=100,
                                           tol_abs_gradnorm=0.1)
    minimizer = ift.RelaxedNewton(ic_newton)

    # Conjugate-gradient solver handed to the Wiener filter energy
    ic_map = ift.GradientNormController(iteration_limit=500,
                                        tol_abs_gradnorm=0.1)
    map_inverter = ift.ConjugateGradient(controller=ic_map)

    # Conjugate-gradient solver handed to CriticalPowerEnergy
    ic_power = ift.GradientNormController(iteration_limit=200,
                                          tol_abs_gradnorm=1e-5)
    power_inverter = ift.ConjugateGradient(controller=ic_power)

    # Set starting position: a map synthesized from a tiny flat spectrum,
    # and a flat log-power spectrum t0
    flat_power = ift.Field.full(p_space, 1e-8)
    m0 = ift.power_synthesize(flat_power, real_signal=True)
    t0 = ift.Field.full(p_space, -7.)

    n_iterations = 500
    plot_every = 5
    for i in range(n_iterations):
        # t0 holds the log power spectrum; build the prior covariance from it
        S0 = ift.create_power_operator(h_space, power_spectrum=ift.exp(t0))

        # Initialize the Wiener filter energy for the current prior
        map_energy = ift.library.WienerFilterEnergy(
            position=m0, d=d, R=R, N=N, S=S0, inverter=map_inverter)
        # Wiener filter solution m = D j, with D the inverse of the
        # curvature (applied presumably via map_inverter)
        D0 = map_energy.curvature
        m0 = D0.inverse_times(j)

        # Initialize power energy with the updated map and propagator.
        # NOTE(review): samples=3 presumably sets how many posterior map
        # samples enter the power estimate -- confirm in NIFTy4 docs.
        power_energy = ift.library.CriticalPowerEnergy(
            position=t0, m=m0, D=D0, smoothness_prior=10., samples=3,
            inverter=power_inverter)

        # The minimizer returns (energy, convergence_status)
        power_energy = minimizer(power_energy)[0]

        # Set new log power spectrum
        t0 = power_energy.position

        # Report progress and plot the current estimate; always plot on the
        # last iteration so the final map is not lost (500 % 5 leaves the
        # last updates unplotted otherwise).
        ift.dobj.mprint(i)
        if i % plot_every == 0 or i == n_iterations - 1:
            ift.plotting.plot(fft(m0), name='map.png')