# Star import first so that the explicit imports below cannot be shadowed
# by anything nifty happens to export.
from nifty import *

# numpy is used directly throughout this script; import it explicitly rather
# than relying on the nifty star import re-exporting ``np``.
import numpy as np
import plotly.graph_objs as go
import plotly.offline as pl
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.rank  # plots are written only on rank 0

# Fixed seed so repeated runs give the same mock data (presumably NIFTy's
# random draws go through numpy's global RNG -- TODO confirm).
np.random.seed(42)
class WienerFilterEnergy(Energy):
    """Wiener-filter Hamiltonian wrapped as a NIFTy ``Energy`` for minimizers.

    At position ``x`` the energy is::

        H(x) = 0.5 * <D^{-1} x, x> - <j, x>

    with gradient ``D^{-1} x - j``.  The gradient vanishes at ``x = D j``,
    the Wiener-filter solution.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated.
    D : operator
        Propagator.  Must provide ``inverse_times`` (used to compute
        ``D^{-1} x``) and ``times`` (used by the curvature's
        ``inverse_times``).
    j : Field
        Information source.
    """

    def __init__(self, position, D, j):
        # In principle not necessary, but it makes the constructor signature
        # that ``at`` relies on explicit.
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return a new energy with the same ``D`` and ``j`` at ``position``."""
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Energy value ``0.5 <D^{-1} x, x> - <j, x>`` (real part)."""
        D_inv_x = self.D_inverse_x()
        # Both inner products use ``vdot`` so they are evaluated the same way.
        H = 0.5 * D_inv_x.vdot(self.position) - self.j.vdot(self.position)
        return H.real

    @property
    def gradient(self):
        """Gradient ``D^{-1} x - j``, returned as a real float64 field."""
        g = self.D_inverse_x() - self.j
        # ``np.float`` (an alias of builtin float, i.e. float64) was removed
        # in NumPy 1.24; use the explicit dtype instead.
        real_g = g.copy_empty(dtype=np.float64)
        real_g.val = g.val.real
        return real_g

    @property
    def curvature(self):
        """Curvature of ``H``, which is ``D^{-1}``.

        Only the inverse action is provided: ``curvature.inverse_times(...)``
        is forwarded to ``D.times(...)``.  In other words, applying the
        inverse curvature means applying the propagator ``D``.
        """
        class _InverseCurvature(object):
            # Minimal stand-in exposing only ``inverse_times``.  This is
            # presumably all the Newton-type minimizers call -- TODO confirm.
            def __init__(self, x):
                self.x = x

            def inverse_times(self, *args, **kwargs):
                return self.x.times(*args, **kwargs)

        return _InverseCurvature(self.D)

    @memo
    def D_inverse_x(self):
        """Return ``D^{-1} x`` for this energy's position.

        Wrapped in nifty's ``memo`` decorator.  It presumably caches the
        result per instance so that ``value`` and ``gradient`` share one
        application of ``D^{-1}``.
        """
        # Use this instance's propagator, not a module-level ``D``.  The
        # module-level name only exists when run as a script, and it may not
        # be the operator this energy was built with.
        return self.D.inverse_times(self.position)
if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up the signal space, its harmonic partner (target of the FFT) and
    # the power space on which the power spectrum lives.
    s_space = RGSpace([512, 512])
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Signal prior: power spectrum P(k) = 42 / (k + 1)^3 and the signal
    # covariance operator S built from it.
    def pow_spec(k):
        # Float numerator guards against integer division under Python 2.
        return 42. / (k + 1) ** 3

    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Mock signal: synthesize a real harmonic-space field from the amplitudes
    # sqrt(P(k)) and transform it to position space.
    sp = Field(p_space, val=lambda z: pow_spec(z) ** 0.5,
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # Measurement model: the response R smooths the signal, and the noise has
    # constant (diagonal) variance var(s) / signal_to_noise.
    R = SmoothingOperator(s_space, sigma=0.01)
    # Alternative response that masks a square region instead of smoothing:
    # R = DiagonalOperator(s_space, diagonal=1.)
    # R._diagonal.val[200:400, 200:400] = 0

    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # Mock data d = R s + n.
    d = R(ss) + n

    # Reconstruction objects: information source j = R^dagger N^{-1} d and
    # the propagator D.
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    def distance_measure(energy, iteration):
        """Minimizer callback: print iteration and relative error |x-s|/|s|."""
        x = energy.position
        print(iteration, ((x - ss).norm() / ss.norm()).real)

    # Alternative minimizers, kept for experimentation:
    # minimizer = SteepestDescent(convergence_tolerance=0,
    #                             iteration_limit=50,
    #                             callback=distance_measure)
    # minimizer = VL_BFGS(convergence_tolerance=0,
    #                     iteration_limit=50,
    #                     callback=distance_measure,
    #                     max_history_length=3)
    minimizer = RelaxedNewton(convergence_tolerance=0,
                              iteration_limit=2,
                              callback=distance_measure)

    # Start from a constant field and minimize the Hamiltonian.
    m0 = Field(s_space, val=1.)
    energy = WienerFilterEnergy(position=m0, D=D, j=j)
    energy, convergence = minimizer(energy)
    m = energy.position

    # Gather the full arrays on every task, in the same order as before (the
    # original called get_full_data on all ranks, presumably because it is
    # collective -- TODO confirm). Plot only on rank 0.
    heatmaps = [
        ('data.html', d.val.get_full_data().real),
        ('ss.html', ss.val.get_full_data().real),
        ('sh.html', sh.val.get_full_data().real),
    ]
    j_full = j.val.get_full_data()  # gathered once, reused for 3 plots
    heatmaps += [
        ('j.html', j_full.real),
        ('j_abs.html', np.abs(j_full)),
        ('j_phase.html', np.angle(j_full)),
        ('map.html', m.val.get_full_data().real),
    ]

    if rank == 0:
        for filename, z in heatmaps:
            pl.plot([go.Heatmap(z=z)], filename=filename)