# wiener_filter.py

from nifty import *
import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI
# World communicator and this process's rank; `rank` gates rank-0-only
# output in the __main__ section.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()


if __name__ == "__main__":
    # Explicit import: `np` is used below for the noise standard deviation and
    # should not depend on whatever `from nifty import *` happens to re-export.
    import numpy as np

    # Passed to every NIFTy object created below.
    # NOTE(review): 'not' presumably means "do not distribute data across MPI
    # ranks" -- confirm against the NIFTy/d2o documentation.
    distribution_strategy = 'not'

    # Setting up the geometry: a 512x512 regular grid (signal space), the
    # harmonic space taken from the FFT operator's target, and the power space
    # built on that harmonic space.
    s_space = RGSpace([512, 512])
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # Creating the mock data.
    # Power spectrum P(k) = 42 / (k + 1)**3, used both for the signal
    # covariance operator S and for synthesizing the mock signal.
    pow_spec = (lambda k: 42 / (k + 1) ** 3)

    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # Synthesize a signal realisation in harmonic space from the power
    # spectrum (real_signal=True) and transform it back to signal space.
    sp = Field(p_space, val=pow_spec,
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # Instrument response: a smoothing operator with sigma=0.1.
    R = SmoothingOperator(s_space, sigma=0.1)

    # Noise level is set relative to the signal: N's diagonal is
    # var(ss)/signal_to_noise, and the drawn noise uses the matching
    # std(ss)/sqrt(signal_to_noise), so the covariance and the realisation agree.
    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)

    # Data: smoothed signal plus noise.
    d = R(ss) + n

    # Wiener filter: build the source term j = R^adjoint N^-1 d and the
    # propagator D from S, N and R. The reconstruction is then m = D j.
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    m = D(j)

    # Gather the fields into plain arrays on every rank (the calls sit outside
    # the rank check).
    # NOTE(review): presumably get_full_data() must run on all ranks because
    # gathering is collective -- confirm before moving it under `rank == 0`.
    d_data = d.val.get_full_data().real
    m_data = m.val.get_full_data().real
    ss_data = ss.val.get_full_data().real

    # Only rank 0 writes the HTML heatmaps (data, reconstruction, true signal).
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html')
        pl.plot([go.Heatmap(z=m_data)], filename='map.html')
        pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')