nonlinear_wiener_filter.py 5.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175

from nifty import *

import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank

np.random.seed(42)

class NonlinearResponse(LinearOperator):
    def __init__(self, FFT, Instrument, function, derivative, default_spaces=None):
        super(NonlinearResponse, self).__init__(default_spaces)
        self._domain = FFT.target
        self._target = Instrument.target
        self.function = function
        self.derivative = derivative
        self.I = Instrument
        self.FFT = FFT


    def _times(self, x, spaces=None):
        return self.I(self.function(self.FFT.adjoint_times(x)))

    def _adjoint_times(self, x, spaces=None):
        return self.FFT(self.function(self.I.adjoint_times(x)))

    def derived_times(self, x, position):
        position_derivative = self.derivative(self.FFT.adjoint_times(position))
        return self.I(position_derivative * self.FFT.adjoint_times(x))

    def derived_adjoint_times(self, x, position):
        position_derivative = self.derivative(self.FFT.adjoint_times(position))
        return self.FFT(position_derivative * self.I.adjoint_times(x))

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def unitary(self):
        return False



if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up position space
    s_space = RGSpace([128,128])
    # s_space = HPSpace(32)

    # Define harmonic transformation and associated harmonic space
    fft = FFTOperator(s_space)
    h_space = fft.target[0]

    # Setting up power space
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Choosing the prior correlation structure and defining correlation operator
    pow_spec = (lambda k: (0.42 / (k + 1) ** 3))
    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Drawing a sample sh from the prior distribution in harmonic space
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)


    # Choosing the measurement instrument
#    Instrument = SmoothingOperator(s_space, sigma=0.01)
    Instrument = DiagonalOperator(s_space, diagonal=1.)
#    Instrument._diagonal.val[200:400, 200:400] = 0

    #   Choosing nonlinearity

    # function = exp
    # derivative = exp
    def function(x):
        return 0.5 * tanh(x) + 0.5
    def derivative(x):
        return 0.5*(1 - tanh(x)**2)
    # def function(x):
    #     return x
    # def derivative(x):
    #     return 1
    #
    # def function(x):
    #     return 0.5*x**2 + x
    # def derivative(x):
    #     return x + 1
    #
    # def function(x):
    #     return 0.9*x**4 +0.2*x**2 + x
    # def derivative(x):
    #     return 0.9*4*x**3 + 0.4*x +1
    #

    #Adding a harmonic transformation to the instrument
    R = NonlinearResponse(fft, Instrument, function, derivative)
    noise = 0.01
    N = DiagonalOperator(s_space, diagonal=noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=sqrt(noise),
                          mean=0)

    # Creating the mock data
    d = R(sh) + n

    # Choosing the minimization strategy

    def convergence_measure(energy, iteration): # returns current energy
        x = energy.value
        print (x, iteration)

#    minimizer = SteepestDescent(convergence_tolerance=0,
#                                iteration_limit=50,
#                                callback=convergence_measure)
#
    minimizer = RelaxedNewton(convergence_tolerance=0,
                              convergence_level=1,
                              iteration_limit=10,
                              callback=convergence_measure)

    # minimizer = VL_BFGS(convergence_tolerance=0,
    #                    iteration_limit=50,
    #                    callback=convergence_measure,
    #                    max_history_length=3)
    #
    #

    # Setting starting position
    m0 = Field.from_random(random_type="normal",domain = h_space, std=0.001)

    # Initializing the Wiener Filter energy
    energy = NonlinearWienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S)

    # Solving the problem with chosen minimization strategy
    (energy, convergence) = minimizer(energy)

    # Transforming fields to position space for plotting

    ss = fft.adjoint_times(sh)
    m = fft.adjoint_times(energy.position)


    # Plotting

    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html')


    ss_data = ss.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')

    sh_data = sh.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')


    m_data = m.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=m_data)], filename='map.html')