Commit 90db729d authored by Theo Steininger's avatar Theo Steininger

Refactored wiener_filter_easy.py

parent 92422402
import numpy as np
from nifty import RGSpace, PowerSpace, Field, FFTOperator, ComposedOperator,\
SmoothingOperator, DiagonalOperator, create_power_operator
from nifty.library import WienerFilterCurvature
from nifty import *
#import plotly.offline as pl #import plotly.offline as pl
#import plotly.graph_objs as go #import plotly.graph_objs as go
...@@ -10,36 +14,37 @@ rank = comm.rank ...@@ -10,36 +14,37 @@ rank = comm.rank
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'not' distribution_strategy = 'fftw'
#Setting up physical constants # Setting up physical constants
#total length of Interval or Volume the field lives on, e.g. in meters # total length of Interval or Volume the field lives on, e.g. in meters
L = 2. L = 2.
#typical distance over which the field is correlated (in same unit as L) # typical distance over which the field is correlated (in same unit as L)
correlation_length = 0.1 correlation_length = 0.1
#variance of field in position space sqrt(<|s_x|^2>) (in unit of s) # variance of field in position space sqrt(<|s_x|^2>) (in unit of s)
field_variance = 2. field_variance = 2.
#smoothing length of response (in same unit as L) # smoothing length of response (in same unit as L)
response_sigma = 0.1 response_sigma = 0.1
#defining resolution (pixels per dimension) # defining resolution (pixels per dimension)
N_pixels = 512 N_pixels = 512
#Setting up derived constants # Setting up derived constants
k_0 = 1./correlation_length k_0 = 1./correlation_length
#note that field_variance**2 = a*k_0/4. for this analytic form of power # note that field_variance**2 = a*k_0/4. for this analytic form of power
#spectrum # spectrum
a = field_variance**2/k_0*4. a = field_variance**2/k_0*4.
pow_spec = (lambda k: a / (1 + k/k_0) ** 4) pow_spec = (lambda k: a / (1 + k/k_0) ** 4)
pixel_width = L/N_pixels pixel_length = L/N_pixels
# Setting up the geometry # Setting up the geometry
s_space = RGSpace([N_pixels, N_pixels], distances = pixel_width) s_space = RGSpace([N_pixels, N_pixels], distances=pixel_length)
fft = FFTOperator(s_space) fft = FFTOperator(s_space, domain_dtype=np.float, target_dtype=np.complex)
h_space = fft.target[0] h_space = fft.target[0]
inverse_fft = FFTOperator(h_space, target=s_space,
domain_dtype=np.complex, target_dtype=np.float)
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# Creating the mock data # Creating the mock data
S = create_power_operator(h_space, power_spectrum=pow_spec, S = create_power_operator(h_space, power_spectrum=pow_spec,
...@@ -51,6 +56,7 @@ if __name__ == "__main__": ...@@ -51,6 +56,7 @@ if __name__ == "__main__":
ss = fft.inverse_times(sh) ss = fft.inverse_times(sh)
R = SmoothingOperator(s_space, sigma=response_sigma) R = SmoothingOperator(s_space, sigma=response_sigma)
R_harmonic = ComposedOperator([inverse_fft, R], default_spaces=[0, 0])
signal_to_noise = 1 signal_to_noise = 1
N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True) N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
...@@ -63,7 +69,9 @@ if __name__ == "__main__": ...@@ -63,7 +69,9 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
j = R.adjoint_times(N.inverse_times(d)) j = R_harmonic.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R) wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic)
m = wiener_curvature.inverse_times(j)
m_s = inverse_fft(m)
m = D(j)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment