critical_filtering.py 4.92 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
import numpy as np
Martin Reinecke's avatar
updates  
Martin Reinecke committed
2
import nifty2go as ift
3

4
# Seed NumPy's global random generator with a fixed value so that repeated
# runs of this script start from the same generator state.
np.random.seed(42)
Martin Reinecke's avatar
adjust  
Martin Reinecke committed
5 6
# np.seterr(all="raise",under="ignore")

Jakob Knollmueller's avatar
Jakob Knollmueller committed
7

Philipp Arras's avatar
Philipp Arras committed
8
def plot_parameters(m, t, p, p_d):
    """Plot the current map estimate and write it to ``map.png``.

    Parameters
    ----------
    m
        Map estimate; it is passed through ``fft.adjoint_times`` before
        plotting, and only the real part of the result is plotted.
    t, p, p_d
        Accepted so existing callers keep working; they are currently not
        plotted or otherwise used.

    Notes
    -----
    This function reads the module-level name ``fft``. In this file that
    name is only bound inside the ``if __name__ == "__main__"`` section, so
    calling the function before that assignment raises ``NameError``.
    """
    position_map = fft.adjoint_times(m)
    ift.plotting.plot(position_map.real, name='map.png')
Jakob Knollmueller's avatar
Jakob Knollmueller committed
14

Jakob Knollmueller's avatar
Jakob Knollmueller committed
15

Martin Reinecke's avatar
updates  
Martin Reinecke committed
16
class AdjointFFTResponse(ift.LinearOperator):
    """Linear operator that applies ``FFT.adjoint_times`` followed by ``R``.

    The operator's domain is ``FFT.target`` and its target is ``R.target``.
    """

    def __init__(self, FFT, R):
        """Store both operators and derive the domain/target from them.

        Parameters
        ----------
        FFT
            Operator providing ``target``, ``adjoint_times`` and ``__call__``.
        R
            Operator providing ``target``, ``adjoint_times`` and ``__call__``.
        """
        super(AdjointFFTResponse, self).__init__()
        self.FFT = FFT
        self.R = R
        self._domain = FFT.target
        self._target = R.target

    @property
    def domain(self):
        """Domain of this operator (``FFT.target``)."""
        return self._domain

    @property
    def target(self):
        """Target of this operator (``R.target``)."""
        return self._target

    @property
    def unitary(self):
        """Always ``False``."""
        return False

    def _times(self, x):
        """Return ``R`` applied to the adjoint transform of ``x``."""
        transformed = self.FFT.adjoint_times(x)
        return self.R(transformed)

    def _adjoint_times(self, x):
        """Return ``FFT`` applied to ``R.adjoint_times(x)``."""
        back_projected = self.R.adjoint_times(x)
        return self.FFT(back_projected)

Philipp Arras's avatar
Philipp Arras committed
42

43 44
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Spaces
    # ------------------------------------------------------------------
    # Position space: a regular 128 x 128 grid.
    s_space = ift.RGSpace([128, 128])
    # s_space = ift.HPSpace(32)

    # Harmonic transform and the associated harmonic space.
    # NOTE: ``fft`` must remain a module-level name; plot_parameters()
    # looks it up as a global.
    fft = ift.FFTOperator(s_space)
    h_space = fft.target[0]

    # Power space over the harmonic space, with logarithmic bin boundaries.
    p_space = ift.PowerSpace(
        h_space,
        binbounds=ift.PowerSpace.useful_binbounds(h_space, logarithmic=True))

    # ------------------------------------------------------------------
    # Mock signal, instrument, noise and data
    # ------------------------------------------------------------------
    # Prior power spectrum used to generate the mock signal.
    def p_spec(k):
        return .5 / (k + 1) ** 3

    # Draw a signal realisation sh in harmonic space from that spectrum.
    sp = ift.PS_field(p_space, p_spec)
    sh = ift.power_synthesize(sp, real_signal=True)

    # Measurement instrument: a diagonal operator filled with ones.
    # Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = ift.DiagonalOperator(ift.Field(s_space, 1.))
    # Instrument._diagonal.val[200:400, 200:400] = 0
    # Instrument._diagonal.val[64:512-64, 64:512-64] = 0

    # Full response: adjoint harmonic transform followed by the instrument.
    R = AdjointFFTResponse(fft, Instrument)

    # Noise covariance operator and one noise realisation with matching
    # standard deviation.
    noise = 1.
    N = ift.DiagonalOperator(ift.Field.full(s_space, noise).weight(1))
    n = ift.Field.from_random(domain=s_space,
                              random_type='normal',
                              std=np.sqrt(noise),
                              mean=0)

    # Mock data: response applied to the signal, plus noise.
    d = R(sh) + n

    # The information source
    j = R.adjoint_times(N.inverse_times(d))

    # Log power of the transformed data on the bins of p_space; handed to
    # plot_parameters() inside the loop below.
    data_power = ift.log(ift.power_analyze(fft(d),
                                           binbounds=p_space.binbounds))
    ift.plotting.plot(d.real, name="data.png")

    # ------------------------------------------------------------------
    # Minimizers
    # ------------------------------------------------------------------
    # Only minimizer1 is used in the loop below; minimizer2 and minimizer3
    # are ready-made alternatives that can be swapped in.
    IC1 = ift.GradientNormController(verbose=True, iteration_limit=100,
                                     tol_abs_gradnorm=0.1)
    minimizer1 = ift.RelaxedNewton(IC1)
    IC2 = ift.GradientNormController(verbose=True, iteration_limit=100,
                                     tol_abs_gradnorm=0.1)
    minimizer2 = ift.VL_BFGS(IC2, max_history_length=20)
    IC3 = ift.GradientNormController(verbose=True, iteration_limit=1000,
                                     tol_abs_gradnorm=0.1)
    minimizer3 = ift.SteepestDescent(IC3)

    # ------------------------------------------------------------------
    # Starting position
    # ------------------------------------------------------------------
    # Initial map: a realisation drawn from a flat, very small spectrum.
    flat_power = ift.Field.full(p_space, 1e-8)
    m0 = ift.power_synthesize(flat_power, real_signal=True)

    # Initial log-spectrum estimate: -7 in every power bin.
    t0 = ift.Field(p_space,
                   val=ift.dobj.from_global_data(-7.))

    # ------------------------------------------------------------------
    # Alternate between map reconstruction and power-spectrum update
    # ------------------------------------------------------------------
    for i in range(500):
        # Signal covariance implied by the current log-spectrum estimate.
        S0 = ift.create_power_operator(h_space, power_spectrum=ift.exp(t0))

        # Initialize non-linear Wiener Filter energy
        ICI = ift.GradientNormController(verbose=False, name="ICI",
                                         iteration_limit=500,
                                         tol_abs_gradnorm=0.1)
        map_inverter = ift.ConjugateGradient(controller=ICI)
        map_energy = ift.library.WienerFilterEnergy(position=m0,
                                                    d=d, R=R, N=N, S=S0,
                                                    inverter=map_inverter)

        # Solve the Wiener Filter analytically: apply the inverse of the
        # energy's curvature to the information source.
        D0 = map_energy.curvature
        m0 = D0.inverse_times(j)

        # Initialize power energy with updated parameters
        ICI2 = ift.GradientNormController(name="powI",
                                          verbose=False,
                                          iteration_limit=200,
                                          tol_abs_gradnorm=1e-5)
        power_inverter = ift.ConjugateGradient(controller=ICI2)
        power_energy = ift.library.CriticalPowerEnergy(
            position=t0, m=m0, D=D0, smoothness_prior=10., samples=3,
            inverter=power_inverter)

        # The second return value of the minimizer is not needed here.
        power_energy, _ = minimizer1(power_energy)

        # Set new power spectrum
        t0 = power_energy.position.real

        # Report progress and plot the current estimate every fifth step.
        ift.dobj.mprint(i)
        if i % 5 == 0:
            plot_parameters(m0, t0, ift.log(sp), data_power)