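# wiener_filter_hamiltonian.py
# Wiener filter reconstruction of a mock 2-D signal by minimizing the
# Wiener filter energy (information Hamiltonian) with NIFTy minimizers.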

from nifty import *
3

4
5
import plotly.offline as pl
import plotly.graph_objs as go
6
7
8
9
10

from mpi4py import MPI
comm = MPI.COMM_WORLD
# Rank of this process; used below so that only rank 0 writes the plots.
rank = comm.rank

# Fixed seed so the mock signal and noise are reproducible; every rank uses
# the same seed.
np.random.seed(42)
class WienerFilterEnergy(Energy):
    """Wiener filter energy (information Hamiltonian) at a given position.

    For a position ``x`` the energy is ``H(x) = 0.5 * (D^{-1} x).x - j.x``.
    Its gradient is ``D^{-1} x - j``, and its curvature is ``D^{-1}``.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated.
    D : operator
        Propagator; must provide ``inverse_times`` and ``times``
        (e.g. the ``PropagatorOperator`` built in ``__main__``).
    j : Field
        Information source, ``R^dagger N^{-1} d`` in ``__main__``.
    """

    def __init__(self, position, D, j):
        # Not strictly necessary, but it makes the extra constructor
        # arguments (D, j) explicit in the signature.
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return a new energy for the same problem (D, j) at `position`."""
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Real part of ``0.5 * (D^{-1} x).x - j.x`` at the current position."""
        D_inv_x = self.D_inverse_x()
        H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
        return H.real

    @property
    def gradient(self):
        """Gradient ``D^{-1} x - j``, returned as a float-typed real field."""
        g = self.D_inverse_x() - self.j
        # Build a float64 container and copy only the real part of g into it.
        return_g = g.copy_empty(dtype=np.float64)
        return_g.val = g.val.real
        return return_g

    @property
    def curvature(self):
        """Curvature ``D^{-1}``, exposed only through its inverse.

        Returns a thin wrapper whose ``inverse_times`` applies ``D.times``,
        because the inverse of the curvature ``D^{-1}`` is ``D`` itself.
        NOTE(review): only ``inverse_times`` is provided; this assumes the
        minimizer never calls other methods on the curvature.
        """
        class _InverseIsD(object):
            def __init__(self, x):
                self.x = x

            def inverse_times(self, *args, **kwargs):
                return self.x.times(*args, **kwargs)

        return _InverseIsD(self.D)

    @memo
    def D_inverse_x(self):
        """Return ``D^{-1}`` applied to the current position.

        Wrapped in ``memo`` (presumably caching per instance, so `value` and
        `gradient` share one inversion).
        """
        # Use the propagator stored on this instance, not a module-level
        # name, so the class also works when imported from another module.
        return self.D.inverse_times(self.position)
if __name__ == "__main__":

    distribution_strategy = 'not'

    # Set up spaces and fft transformation
    s_space = RGSpace([512, 512], dtype=np.float64)
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Create the field instances and power operator
    def pow_spec(k):
        """Signal power spectrum P(k) = 42 / (k + 1)^3."""
        return 42. / (k + 1) ** 3

    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Draw a signal realization whose amplitude spectrum is sqrt(P(k)).
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # Model the measurement process: smoothing response plus white noise.
    R = SmoothingOperator(s_space, sigma=0.01)
    # Alternative response: identity with an unobserved (masked) square.
    # R = DiagonalOperator(s_space, diagonal=1.)
    # R._diagonal.val[200:400, 200:400] = 0

    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise,
                         bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # Create mock data
    d = R(ss) + n

    # Set up reconstruction objects: information source j and propagator D.
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    def distance_measure(energy, iteration):
        """Minimizer callback: print the relative error w.r.t. the true signal."""
        x = energy.position
        # norm() may involve communication, so evaluate it on every rank
        # and only print on rank 0 to avoid duplicated output.
        rel_error = ((x - ss).norm() / ss.norm()).real
        if rank == 0:
            print(iteration, rel_error)

    # Alternative minimizers:
    # minimizer = SteepestDescent(convergence_tolerance=0,
    #                             iteration_limit=50,
    #                             callback=distance_measure)
    # minimizer = VL_BFGS(convergence_tolerance=0,
    #                     iteration_limit=50,
    #                     callback=distance_measure,
    #                     max_history_length=3)

    minimizer = RelaxedNewton(convergence_tolerance=0,
                              iteration_limit=2,
                              callback=distance_measure)

    m0 = Field(s_space, val=1)

    energy = WienerFilterEnergy(position=m0, D=D, j=j)

    energy, convergence = minimizer(energy)

    m = energy.position

    # get_full_data() is called on every rank (it may need to gather
    # distributed data); only rank 0 writes the HTML heatmaps.
    d_data = d.val.get_full_data().real
    ss_data = ss.val.get_full_data().real
    sh_data = sh.val.get_full_data().real
    j_full = j.val.get_full_data()
    m_data = m.val.get_full_data().real

    heatmaps = [
        (d_data, 'data.html'),
        (ss_data, 'ss.html'),
        (sh_data, 'sh.html'),
        (j_full.real, 'j.html'),
        (np.abs(j_full), 'j_abs.html'),
        (np.angle(j_full), 'j_phase.html'),
        (m_data, 'map.html'),
    ]
    if rank == 0:
        for data, filename in heatmaps:
            pl.plot([go.Heatmap(z=data)], filename=filename)