critical_filtering.py 5.48 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
2
import nifty as ift
3
from nifty.library.critical_filter import CriticalPowerEnergy
Philipp Arras's avatar
Philipp Arras committed
4
from nifty.library.wiener_filter import WienerFilterEnergy
5

Philipp Arras's avatar
Philipp Arras committed
6
7
import plotly.graph_objs as go
import plotly.offline as pl
8
from mpi4py import MPI
Philipp Arras's avatar
Philipp Arras committed
9

10
11
12
# MPI communicator; ``rank`` is used below to restrict file output to a
# single process (rank 0).
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Fix NumPy's global random seed so its random stream is identical on
# every run.
np.random.seed(44)
14

Jakob Knollmueller's avatar
Jakob Knollmueller committed
15

16
def plot_parameters(m, t, p, p_sig, p_d, harmonic_transform=None):
    """Write the current reconstruction state to plotly HTML files.

    As called from this script's main loop, ``m`` is the harmonic-space
    signal, ``t`` is the current log power-spectrum estimate, and ``p``,
    ``p_sig`` and ``p_d`` are log power spectra (prior, realized signal,
    data).

    Parameters
    ----------
    m : Field
        Field in the target space of ``harmonic_transform``. It is mapped
        back with ``harmonic_transform.adjoint_times`` and drawn as a heatmap
        into ``map.html``.
    t, p, p_sig, p_d : Field
        Power-space fields. They are drawn as line plots against
        ``log(kindex)`` of ``t``'s first domain into ``t.html``.
    harmonic_transform : LinearOperator, optional
        Operator whose adjoint maps ``m`` back. Defaults to the module-level
        ``fft`` created in the ``__main__`` section of this script, which
        keeps the original call signature working.
    """
    if harmonic_transform is None:
        harmonic_transform = fft

    # NOTE(review): if kindex[0] == 0 (typical for the zero mode) this yields
    # -inf and a RuntimeWarning; confirm how plotly renders that point.
    x = np.log(t.domain[0].kindex)
    m = harmonic_transform.adjoint_times(m)

    # Gather full arrays on every rank. These gathers are presumably
    # collective when the data are distributed, so they must not be skipped
    # on non-root ranks.
    m = m.val.get_full_data().real
    t = t.val.get_full_data().real
    p = p.val.get_full_data().real
    pd_sig = p_sig.val.get_full_data().real
    p_d = p_d.val.get_full_data().real

    # Only one process writes the HTML files (matches the rank-0 guard used
    # for data.html); otherwise every rank would write the same files.
    if rank != 0:
        return

    pl.plot([go.Heatmap(z=m)], filename='map.html', auto_open=False)
    pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
             go.Scatter(x=x, y=p_d), go.Scatter(x=x, y=pd_sig)],
            filename="t.html", auto_open=False)
Jakob Knollmueller's avatar
Jakob Knollmueller committed
28

Jakob Knollmueller's avatar
Jakob Knollmueller committed
29

Martin Reinecke's avatar
Martin Reinecke committed
30
class AdjointFFTResponse(ift.LinearOperator):
    """Linear response ``R`` preceded by the adjoint of ``FFT``.

    Forward application maps a field on ``FFT.target`` back with
    ``FFT.adjoint_times`` and then applies ``R``. The adjoint applies
    ``R.adjoint_times`` followed by ``FFT``. The operator's domain is
    ``FFT.target`` and its target is ``R.target``.
    """

    def __init__(self, FFT, R, default_spaces=None):
        super(AdjointFFTResponse, self).__init__(default_spaces)
        self.FFT = FFT
        self.R = R
        self._domain = FFT.target
        self._target = R.target

    def _times(self, x, spaces=None):
        # ``spaces`` is accepted for interface compatibility but not used.
        fft_domain_field = self.FFT.adjoint_times(x)
        return self.R(fft_domain_field)

    def _adjoint_times(self, x, spaces=None):
        # ``spaces`` is accepted for interface compatibility but not used.
        back_projected = self.R.adjoint_times(x)
        return self.FFT(back_projected)

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def unitary(self):
        return False

Philipp Arras's avatar
Philipp Arras committed
56

57
58
59
60
61
if __name__ == "__main__":
    # Demo: draw a mock signal from a known power spectrum and observe it
    # with additive Gaussian noise. Then alternate between a Wiener-filter
    # map estimate and a critical-filter update of the log power spectrum.

    distribution_strategy = 'not'

    # Set up position space
    dist = 1/128. * 10
    s_space = ift.RGSpace([128, 128], distances=[dist, dist])
    # s_space = ift.HPSpace(32)

    # Define harmonic transformation and associated harmonic space
    fft = ift.FFTOperator(s_space)
    h_space = fft.target[0]

    # Set up power space with logarithmically spaced bins
    p_space = ift.PowerSpace(h_space,
                             binbounds=ift.PowerSpace.useful_binbounds(
                                       h_space, logarithmic=True),
                             distribution_strategy=distribution_strategy)

    # Choose the true prior power spectrum
    p_spec = (lambda k: (.5 / (k + 1) ** 3))
    # p_spec = (lambda k: 1)

    # Draw a sample sh from the prior distribution in harmonic space
    sp = ift.Field(p_space, val=p_spec,
                   distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)

    # Choose the measurement instrument
    # Instrument = ift.SmoothingOperator(s_space, sigma=0.01)
    Instrument = ift.DiagonalOperator(s_space, diagonal=1.)
    # Instrument._diagonal.val[200:400, 200:400] = 0
    # Instrument._diagonal.val[64:512-64, 64:512-64] = 0

    # Add a harmonic transformation to the instrument
    R = AdjointFFTResponse(fft, Instrument)

    # Noise covariance N and one noise realization n with variance `noise`
    noise = .1
    ndiag = ift.Field(s_space, noise).weight(1)
    N = ift.DiagonalOperator(s_space, ndiag)
    n = ift.Field.from_random(domain=s_space,
                              random_type='normal',
                              std=np.sqrt(noise),
                              mean=0)

    # Create mock data
    d = R(sh) + n

    # The information source
    j = R.adjoint_times(N.inverse_times(d))

    # Loop-invariant reference spectra for plotting; sh, d and sp are never
    # modified below, so these are computed once.
    realized_power = ift.log(sh.power_analyze(binbounds=p_space.binbounds))
    data_power = ift.log(fft(d).power_analyze(binbounds=p_space.binbounds))
    log_sp = ift.log(sp)

    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html', auto_open=False)

    # Minimization strategy for the power-spectrum update
    IC1 = ift.GradientNormController(iteration_limit=5,
                                     tol_abs_gradnorm=0.1)
    minimizer1 = ift.RelaxedNewton(IC1)

    # Set starting position
    flat_power = ift.Field(p_space, val=1e-8)
    m0 = flat_power.power_synthesize(real_signal=True)

    # t0 = ift.Field(p_space, val=np.log(1./(1+p_space.kindex)**2))
    t0 = ift.Field(p_space, val=-5)

    n_iterations = 500
    plot_every = 1  # plot/print diagnostics every `plot_every` iterations

    for i in range(n_iterations):
        S0 = ift.create_power_operator(h_space, power_spectrum=ift.exp(t0),
                                       distribution_strategy=distribution_strategy)

        # Initialize non-linear Wiener Filter energy
        map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
        # Solve the Wiener Filter analytically
        D0 = map_energy.curvature
        m0 = D0.inverse_times(j)
        # Initialize power energy with updated parameters
        power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
                                           smoothness_prior=1e-15, samples=5)

        power_energy, _ = minimizer1(power_energy)

        # Set new power spectrum
        t0.val = power_energy.position.val.real

        # Plot current estimate
        if rank == 0:
            print(i)
        if i % plot_every == 0:
            plot_parameters(sh, t0, log_sp, realized_power, data_power)
            # Difference between the realized log power and the estimate
            residual = realized_power.val - t0.val
            if rank == 0:
                print(residual)