critical_filtering.py 4.51 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
import numpy as np
Martin Reinecke's avatar
updates    
Martin Reinecke committed
2
import nifty2go as ift
3

Philipp Arras's avatar
Philipp Arras committed
4
5
import plotly.graph_objs as go
import plotly.offline as pl
6

7
# Seed NumPy's global RNG so repeated runs reproduce the same random draws
# (presumably nifty2go's random fields draw from this RNG; verify).
np.random.seed(seed=42)
8

Jakob Knollmueller's avatar
Jakob Knollmueller committed
9

Philipp Arras's avatar
Philipp Arras committed
10
def plot_parameters(m, t, p, p_d, fft_op=None):
    """Write plotly HTML diagnostics for the current filter state.

    Parameters
    ----------
    m : Field
        Current map estimate. It is mapped back with
        ``fft_op.adjoint_times`` before its real part is drawn as a heat map.
    t : Field
        Current log power spectrum estimate. The x axis of the spectrum plot
        is ``ift.log(t.domain[0].kindex)``.
    p : Field
        Reference log power spectrum (the demo passes the true one).
    p_d : Field
        Log power spectrum of the data.
    fft_op : LinearOperator, optional
        Transform used to bring ``m`` back from harmonic space. When omitted,
        the module-level ``fft`` created by the demo script is used, so
        existing four-argument calls keep working.

    Writes ``map.html`` (heat map of ``m``) and ``t.html`` (the three
    spectra) to relative paths without opening a browser.
    """
    if fft_op is None:
        # Backward compatibility: originally this read the script-level
        # global ``fft`` unconditionally.
        fft_op = fft
    x = ift.log(t.domain[0].kindex)
    m = fft_op.adjoint_times(m)

    m = m.val.real
    t = t.val.real
    p = p.val.real
    p_d = p_d.val.real
    pl.plot([go.Heatmap(z=m)], filename='map.html', auto_open=False)
    pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
             go.Scatter(x=x, y=p_d)], filename="t.html", auto_open=False)
Jakob Knollmueller's avatar
Jakob Knollmueller committed
20

Jakob Knollmueller's avatar
Jakob Knollmueller committed
21

Martin Reinecke's avatar
updates    
Martin Reinecke committed
22
class AdjointFFTResponse(ift.LinearOperator):
    """Instrument response acting on fields in the FFT's target space.

    The forward application first maps its input from ``FFT.target`` back
    to ``FFT.domain`` with ``FFT.adjoint_times`` and then applies ``R``.
    The adjoint runs the two steps in reverse order: ``R.adjoint_times``
    followed by ``FFT``.

    Parameters
    ----------
    FFT : LinearOperator
        Transform whose target becomes this operator's domain.
    R : LinearOperator
        Instrument whose target becomes this operator's target.
    default_spaces : optional
        Forwarded unchanged to ``ift.LinearOperator.__init__``.
    """

    def __init__(self, FFT, R, default_spaces=None):
        super(AdjointFFTResponse, self).__init__(default_spaces)
        self.FFT = FFT
        self.R = R
        self._domain = FFT.target
        self._target = R.target

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def unitary(self):
        # The composite operator is never declared unitary.
        return False

    def _times(self, x, spaces=None):
        # Undo the transform first, then pass the result through R.
        untransformed = self.FFT.adjoint_times(x)
        return self.R(untransformed)

    def _adjoint_times(self, x, spaces=None):
        # Back-project through R first, then apply the forward transform.
        back_projected = self.R.adjoint_times(x)
        return self.FFT(back_projected)

Philipp Arras's avatar
Philipp Arras committed
48

49
50
if __name__ == "__main__":
    # Critical-filter demo: draw a signal from a known power spectrum,
    # observe it with a unit diagonal response plus Gaussian noise, then
    # alternate between a Wiener-filter map estimate and a power-spectrum
    # estimate (CriticalPowerEnergy), feeding each result into the other.

    # Set up position space
    s_space = ift.RGSpace([128, 128])
    # s_space = ift.HPSpace(32)

    # Define harmonic transformation and associated harmonic space.
    # NOTE: ``fft`` stays a module-level name; plot_parameters reads it.
    fft = ift.FFTOperator(s_space)
    h_space = fft.target[0]

    # Set up power space
    p_space = ift.PowerSpace(h_space)

    # True prior power spectrum used to generate the mock signal
    p_spec = (lambda k: (.5 / (k + 1) ** 3))

    # Draw a sample sh from the prior distribution in harmonic space
    sp = ift.Field(p_space, val=p_spec(p_space.kindex))
    sh = sp.power_synthesize(real_signal=True)

    # Choose the measurement instrument
    # Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = ift.DiagonalOperator(s_space, diagonal=1.)
    # NOTE: the masks below were written for a 512x512 grid.
    # Instrument._diagonal.val[200:400, 200:400] = 0
    # Instrument._diagonal.val[64:512-64, 64:512-64] = 0

    # Add a harmonic transformation to the instrument
    R = AdjointFFTResponse(fft, Instrument)

    noise = 1.
    N = ift.DiagonalOperator(s_space, diagonal=noise, bare=True)
    n = ift.Field.from_random(domain=s_space,
                              random_type='normal',
                              std=ift.sqrt(noise),
                              mean=0)

    # Create mock data
    d = R(sh) + n

    # The information source
    j = R.adjoint_times(N.inverse_times(d))
    data_power = ift.log(fft(d).power_analyze(binbounds=p_space.binbounds))
    d_data = d.val.real
    pl.plot([go.Heatmap(z=d_data)], filename='data.html', auto_open=False)

    # Minimizer for the power-spectrum step
    IC2 = ift.DefaultIterationController(verbose=True, iteration_limit=30)
    minimizer2 = ift.VL_BFGS(IC2, max_history_length=20)

    # Set starting position
    flat_power = ift.Field(p_space, val=1e-8)
    m0 = flat_power.power_synthesize(real_signal=True)

    # Initial log power spectrum guess: t0 = log(1 / (1 + k)^2)
    t0 = ift.Field(p_space, val=ift.log(1./(1+p_space.kindex)**2))

    def ps0(k):
        """Current power spectrum estimate exp(t0) evaluated at |k|.

        ``t0`` is updated in place at the end of every iteration, so each
        call sees the latest estimate. At the power-space bin centres this
        reproduces exp(t0) exactly; np.interp also keeps it valid if the
        caller evaluates other |k| values (assumed to be a NumPy array).
        """
        return np.interp(k, p_space.kindex, np.exp(t0.val))

    for i in range(500):
        # Rebuild the prior from the current power estimate so the map step
        # actually sees the spectra found by the power step.
        S0 = ift.create_power_operator(h_space, power_spectrum=ps0)

        # Wiener filter energy for the current prior S0
        ICI = ift.DefaultIterationController(verbose=True, iteration_limit=60)
        inverter = ift.ConjugateGradient(controller=ICI,
                                         preconditioner=S0.times)
        map_energy = ift.library.WienerFilterEnergy(
            position=m0, d=d, R=R, N=N, S=S0, inverter=inverter)
        # Solve the Wiener Filter analytically: m = D j
        D0 = map_energy.curvature
        m0 = D0.inverse_times(j)
        # Initialize power energy with updated parameters
        power_energy = ift.library.CriticalPowerEnergy(
            position=t0, m=m0, D=D0, smoothness_prior=10., samples=3)

        power_energy, _ = minimizer2(power_energy)

        # Set new power spectrum (read by ps0 on the next iteration)
        t0.val = power_energy.position.val.real

        # Plot current estimate
        print(i)
        if i % 5 == 0:
            plot_parameters(m0, t0, ift.log(sp), data_power)