"""Wiener filter demo that minimizes a ``WienerFilterEnergy``.

Defines ``AdjointFFTResponse`` and, when executed as a script, draws a mock
signal and noisy data, runs a minimizer on the energy and writes plotly
heatmaps of the involved fields.
"""

# The star import supplies the NIFTy names used throughout this file
# (LinearOperator, RGSpace, FFTOperator, Field, ...). It stays first so the
# explicit imports below take precedence on any name clash.
from nifty import *

import numpy as np
import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI

# MPI bookkeeping: `rank` is used by the script part so that only the
# process with rank 0 writes the plot files.
comm = MPI.COMM_WORLD
rank = comm.rank

# Fixed seed so that repeated runs draw the same random numbers.
np.random.seed(42)


class AdjointFFTResponse(LinearOperator):
    """Response operator that chains an adjoint FFT with an instrument.

    Applying the operator first calls ``FFT.adjoint_times`` on the input and
    then applies ``R`` to the result. The adjoint runs the chain backwards:
    ``R.adjoint_times`` followed by ``FFT``.

    Parameters
    ----------
    FFT : operator
        Must provide ``target``, ``adjoint_times`` and be callable. Its
        ``target`` becomes the domain of this operator.
    R : operator
        Must provide ``target``, ``adjoint_times`` and be callable. Its
        ``target`` becomes the target of this operator.
    default_spaces : optional
        Forwarded unchanged to ``LinearOperator.__init__``.
    """

    def __init__(self, FFT, R, default_spaces=None):
        super(AdjointFFTResponse, self).__init__(default_spaces)
        # Inputs live where the FFT maps to; outputs live where R maps to.
        self._domain = FFT.target
        self._target = R.target
        self.R = R
        self.FFT = FFT

    def _times(self, x, spaces=None):
        """Return ``R(FFT.adjoint_times(x))``; ``spaces`` is not used."""
        return self.R(self.FFT.adjoint_times(x))

    def _adjoint_times(self, x, spaces=None):
        """Return ``FFT(R.adjoint_times(x))``; ``spaces`` is not used."""
        return self.FFT(self.R.adjoint_times(x))

    @property
    def domain(self):
        """The ``target`` of the FFT operator passed at construction."""
        return self._domain

    @property
    def target(self):
        """The ``target`` of the operator ``R`` passed at construction."""
        return self._target

    @property
    def unitary(self):
        """Always ``False``."""
        return False


if __name__ == "__main__":
    # Demo script: draw a random signal from a power spectrum, produce noisy
    # mock data through the response R, run a few minimizer steps on the
    # Wiener filter energy and write heatmaps of the involved fields.

    distribution_strategy = 'not'

    # Set up spaces and fft transformation
    s_space = RGSpace([512, 512])
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # create the field instances and power operator
    pow_spec = (lambda k: (42 / (k + 1) ** 3))
    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # sp holds the square root of the power spectrum; sh is a random
    # realization synthesized from it and ss its image under
    # fft.inverse_times.
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # model the measurement process
    Instrument = SmoothingOperator(s_space, sigma=0.01)
#    Instrument = DiagonalOperator(s_space, diagonal=1.)
#    Instrument._diagonal.val[200:400, 200:400] = 0
    R = AdjointFFTResponse(fft, Instrument)

    # Noise covariance and one noise draw, both scaled from the signal
    # variance by the chosen signal-to-noise ratio.
    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # create mock data
    d = R(sh) + n

    # Information source R^dagger N^{-1} d. The j plots at the end of the
    # script read `j`, which was previously never assigned (NameError).
    j = R.adjoint_times(N.inverse_times(d))

    def distance_measure(energy, iteration):
        """Minimizer callback: print the energy value and iteration count."""
        x = energy.value
        print (x, iteration)

#    minimizer = SteepestDescent(convergence_tolerance=0,
#                                iteration_limit=50,
#                                callback=distance_measure)

    minimizer = RelaxedNewton(convergence_tolerance=0,
                              iteration_limit=2,
                              callback=distance_measure)

#    minimizer = VL_BFGS(convergence_tolerance=0,
#                        iteration_limit=50,
#                        callback=distance_measure,
#                        max_history_length=3)

    # NOTE(review): m0 is defined on s_space, while R takes inputs on
    # fft.target and the result is passed to fft.adjoint_times below --
    # presumably the starting position should live on h_space; confirm.
    m0 = Field(s_space, val=1.)

    energy = WienerFilterEnergy(position=m0, R=R, N=N, S=S)
    # NOTE(review): `solution` is not used further in this script --
    # presumably kept for comparison with the minimizer result; confirm.
    solution = energy.analytic_solution()
    (energy, convergence) = minimizer(energy)

    m = fft.adjoint_times(energy.position)

    def plot_heatmap(data, filename):
        """Write `data` as a plotly heatmap to `filename` on rank 0 only."""
        if rank == 0:
            pl.plot([go.Heatmap(z=data)], filename=filename)

    # get_full_data() is deliberately called on every rank and only the
    # plotting itself is restricted to rank 0, as in the original layout.
    # NOTE(review): presumably the gather is a collective operation; confirm.
    plot_heatmap(d.val.get_full_data().real, 'data.html')
    plot_heatmap(ss.val.get_full_data().real, 'ss.html')
    plot_heatmap(sh.val.get_full_data().real, 'sh.html')

    # Gather j once and derive real part, modulus and phase from it.
    j_full = j.val.get_full_data()
    plot_heatmap(j_full.real, 'j.html')
    plot_heatmap(np.abs(j_full), 'j_abs.html')
    plot_heatmap(np.angle(j_full), 'j_phase.html')

    plot_heatmap(m.val.get_full_data().real, 'map.html')

#    grad_data = grad.val.get_full_data().real
#    if rank == 0:
#        pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')