wiener_filter_easy.py 2.38 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
2
import nifty4 as ift
Martin Reinecke's avatar
Martin Reinecke committed
3

4

Martin Reinecke's avatar
Martin Reinecke committed
5
if __name__ == "__main__":
    # Wiener-filter demo: draw a Gaussian random signal from a known power
    # spectrum, observe it through a smoothing response with additive noise,
    # reconstruct it with the Wiener filter and plot signal, data and
    # reconstruction on a common color scale.
    np.random.seed(43)

    # Set up physical constants
    # Total length of interval or volume the field lives on, e.g. in meters
    L = 2.
    # Typical distance over which the field is correlated (in same unit as L)
    correlation_length = 0.1
    # Standard deviation (RMS amplitude) of the field in position space,
    # sqrt(<|s_x|^2>) (in same unit as s). Despite its name this is not a
    # variance: the power spectrum below is scaled by field_variance**2.
    field_variance = 2.
    # Smoothing length of response (in same unit as L)
    response_sigma = 0.01
    # Typical noise amplitude of the measurement; the noise covariance
    # below is noise_level**2 times the identity.
    noise_level = 1.

    # Define resolution (pixels per dimension)
    N_pixels = 256

    # Set up derived constants
    k_0 = 1./correlation_length
    # Define a power spectrum with the right correlation length; it is
    # rescaled further down so the field gets the desired amplitude.
    unscaled_pow_spec = (lambda k: 1. / (1 + k/k_0) ** 4)
    pixel_width = L/N_pixels

    # Set up the geometry
    s_space = ift.RGSpace([N_pixels, N_pixels], distances=pixel_width)
    h_space = s_space.get_default_codomain()
    # Rescale the spectrum: divide out the signal variance the unscaled
    # spectrum yields on h_space, then multiply by the target variance
    # field_variance**2.
    s_var = ift.get_signal_variance(unscaled_pow_spec, h_space)
    pow_spec = (lambda k: unscaled_pow_spec(k)/s_var*field_variance**2)

    HT = ift.HarmonicTransformOperator(h_space, s_space)

    # Create mock data
    Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
    sh = Sh.draw_sample()

    # Response: smooth in harmonic space, then transform to position space.
    R = HT*ift.create_harmonic_smoothing_operator((h_space,), 0,
                                                  response_sigma)
    noiseless_data = R(sh)
    N = ift.ScalingOperator(noise_level**2, s_space)
    n = N.draw_sample()
    d = noiseless_data + n

    # Wiener filter
    # Information source j = R^dagger N^-1 d
    j = R.adjoint_times(N.inverse_times(d))
    IC = ift.GradientNormController(name="inverter", iteration_limit=500,
                                    tol_abs_gradnorm=0.1)
    inverter = ift.ConjugateGradient(controller=IC)
    # Wiener covariance D = (R^dagger N^-1 R + S^-1)^-1. Applying it needs a
    # numerical inversion, so wrap it with the CG inverter; Sh is passed as
    # an approximation of D (presumably used as a preconditioner by
    # InversionEnabler -- see NIFTy docs).
    D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
    D = ift.InversionEnabler(D, inverter, approximation=Sh)
    m = D(j)

    # Plotting
    # Transform signal and reconstruction to position space once instead of
    # re-running the harmonic transform for every min/max/plot call.
    signal_pos = HT(sh)
    reconstruction_pos = HT(m)
    # d already lives on s_space (R and N both target it), so this cast is a
    # no-op kept in case the response's target domain changes.
    d_field = d.cast_domain(s_space)
    zmax = max(signal_pos.max(), d_field.max(), reconstruction_pos.max())
    zmin = min(signal_pos.min(), d_field.min(), reconstruction_pos.min())
    plotdict = {"colormap": "Planck-like", "zmax": zmax, "zmin": zmin}
    ift.plot(signal_pos, name="mock_signal.png", **plotdict)
    ift.plot(d_field, name="data.png", **plotdict)
    ift.plot(reconstruction_pos, name="reconstruction.png", **plotdict)