wiener_filter_easy.py
```python
import numpy as np
import nifty4 as ift


if __name__ == "__main__":
    np.random.seed(43)

    # Set up physical constants
    # Total length of interval or volume the field lives on, e.g. in meters
    L = 2.
    # Typical distance over which the field is correlated (in the same unit
    # as L)
    correlation_length = 0.1
    # Standard deviation of the field in position space, sqrt(<|s_x|^2>)
    # (in units of s)
    field_variance = 2.
    # Smoothing length of the response (in the same unit as L)
    response_sigma = 0.01
    # Define resolution (pixels per dimension)
    N_pixels = 256

    # Set up derived constants
    k_0 = 1./correlation_length
    # Note that field_variance**2 = a*k_0/4 for this analytic form of the
    # power spectrum
    a = field_variance**2/k_0*4.
    pow_spec = (lambda k: a / (1 + k/k_0) ** 4)
    pixel_width = L/N_pixels

    # Set up the geometry
    s_space = ift.RGSpace([N_pixels, N_pixels], distances=pixel_width)
    # Harmonic partner of the signal space and the transform from it back
    # to position space
    h_space = s_space.get_default_codomain()
    HT = ift.HarmonicTransformOperator(h_space, s_space)
    p_space = ift.PowerSpace(h_space)

    # Create mock data
    # Signal covariance S, diagonal in harmonic space with spectrum pow_spec
    Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
    sp = ift.PS_field(p_space, pow_spec)
    # Draw a Gaussian random signal in harmonic space with power spectrum sp;
    # real_signal=True requests a signal that is real in position space
    sh = ift.power_synthesize(sp, real_signal=True)
    # Response: Gaussian smoothing in harmonic space, followed by the
    # transform to position space
    R = HT*ift.create_harmonic_smoothing_operator((h_space,), 0,
                                                  response_sigma)

    noiseless_data = R(sh)
    signal_to_noise = 1.
    # White noise whose standard deviation is set by the signal-to-noise ratio
    noise_amplitude = noiseless_data.val.std()/signal_to_noise
    N = ift.ScalingOperator(noise_amplitude**2, s_space)
    n = ift.Field.from_random(domain=s_space, random_type='normal',
                              std=noise_amplitude,
                              mean=0)

    d = noiseless_data + n

    # Wiener filter
    # Information source j = R^dagger N^-1 d
    j = R.adjoint_times(N.inverse_times(d))
    IC = ift.GradientNormController(name="inverter", iteration_limit=500,
                                    tol_abs_gradnorm=0.1)
    inverter = ift.ConjugateGradient(controller=IC)
    # Posterior covariance D = (R^dagger N^-1 R + S^-1)^-1; InversionEnabler
    # applies it by solving with conjugate gradient
    D = (R.adjoint*N.inverse*R + Sh.inverse).inverse
    # MR FIXME: we can/should provide a preconditioner here as well!
    D = ift.InversionEnabler(D, inverter)
    # Posterior mean (the Wiener filter reconstruction) m = D j
    m = D(j)

    # Plotting
    d_field = ift.Field(s_space, val=d.val)
    zmax = max(HT(sh).max(), d_field.max(), HT(m).max())
    zmin = min(HT(sh).min(), d_field.min(), HT(m).min())
    plotdict = {"colormap": "Planck-like", "zmax": zmax, "zmin": zmin}
    ift.plot(HT(sh), name="mock_signal.png", **plotdict)
    ift.plot(d_field, name="data.png", **plotdict)
    ift.plot(HT(m), name="reconstruction.png", **plotdict)
```
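The mock signal in the script is drawn with `ift.power_synthesize`, which produces a Gaussian random field whose power spectrum is `pow_spec`. The underlying idea can be sketched in plain NumPy on a 1D periodic grid: colour white noise in Fourier space by sqrt(P(k)) and transform back. This is a sketch under stated assumptions, not NIFTy's implementation. The function `synthesize_real_field` is hypothetical, the 1D grid is a simplification of the script's 2D `RGSpace`, and the amplitudes follow NumPy's FFT normalization rather than NIFTy's volume conventions.

```python
import numpy as np


def synthesize_real_field(n_pix, length, pow_spec, rng):
    """Draw a real, periodic 1D Gaussian random field with spectrum pow_spec.

    Sketch only: amplitudes follow NumPy's FFT normalization, not NIFTy's
    volume conventions, so absolute variances are not directly comparable.
    """
    dx = length / n_pix
    # Angular wavenumbers of the non-negative-frequency (rfft) modes
    k = 2 * np.pi * np.fft.rfftfreq(n_pix, d=dx)
    # Fourier modes of real white noise; irfft below treats them as the
    # half-spectrum of a Hermitian signal, which keeps the result real
    white = np.fft.rfft(rng.standard_normal(n_pix))
    # Colour the white noise with the square root of the power spectrum
    return np.fft.irfft(white * np.sqrt(pow_spec(k)), n=n_pix)


# Constants as in the script, on a hypothetical 1D grid of 256 pixels
k_0 = 1.0 / 0.1              # 1/correlation_length
a = 2.0**2 / k_0 * 4.0       # field_variance**2/k_0*4
pow_spec = lambda k: a / (1 + k / k_0) ** 4
s = synthesize_real_field(256, 2.0, pow_spec, np.random.default_rng(43))
```

With a flat spectrum (P(k) = 1) the function returns the input white noise unchanged, which is a quick sanity check of the transform pair.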
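In the script, `ift.InversionEnabler` applies D by solving (R^dagger N^-1 R + S^-1) m = j iteratively with `ift.ConjugateGradient`, and `ift.GradientNormController` stops the iteration once the gradient norm falls below `tol_abs_gradnorm=0.1` or after `iteration_limit=500` steps. A minimal sketch of conjugate gradient with that kind of stopping rule follows. It assumes the "gradient" is that of the quadratic 0.5 x^T A x - b^T x, i.e. the residual A x - b; the function `conjugate_gradient` is hypothetical and is not NIFTy's code.

```python
import numpy as np


def conjugate_gradient(apply_A, b, tol_abs_gradnorm=0.1, iteration_limit=500):
    """Solve A x = b for a symmetric positive-definite A given as a function.

    Stops when the norm of the gradient of 0.5 x^T A x - b^T x (the residual
    A x - b) drops below tol_abs_gradnorm (assumed > 0), or after
    iteration_limit iterations.
    """
    b = np.asarray(b, dtype=float)
    x = np.zeros_like(b)
    r = b - apply_A(x)        # residual, i.e. minus the gradient
    p = r.copy()              # first search direction
    rr = r @ r
    for _ in range(iteration_limit):
        if np.sqrt(rr) < tol_abs_gradnorm:
            break
        Ap = apply_A(p)
        alpha = rr / (p @ Ap)             # exact line search along p
        x += alpha * p
        r -= alpha * Ap
        rr_new = r @ r
        p = r + (rr_new / rr) * p         # keep directions A-conjugate
        rr = rr_new
    return x


# Usage on a small, well-conditioned SPD system
rng = np.random.default_rng(0)
M = rng.standard_normal((20, 20))
A = M @ M.T + 20 * np.eye(20)
b = rng.standard_normal(20)
x = conjugate_gradient(lambda v: A @ v, b, tol_abs_gradnorm=1e-10)
```

Passing the operator as a function instead of a matrix is in the spirit of how the script applies operators: with 256 x 256 pixels, D would be far too large to store densely.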
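The Wiener filter section computes j = R^dagger N^-1 d and m = D j with D = (R^dagger N^-1 R + S^-1)^-1. On a small 1D grid all of these operators fit into dense matrices, which makes the formulas easy to check directly. The sketch below assumes a periodic grid, builds S and R as matrices that are diagonal in Fourier space (mirroring `Sh` and the harmonic smoothing operator), and uses white noise N as the script does. The helpers `fourier_diagonal` and `wiener_filter` are hypothetical, and the 64-pixel grid is an arbitrary choice for illustration.

```python
import numpy as np


def fourier_diagonal(n_pix, dx, f):
    """Dense matrix of a translation-invariant operator with spectrum f(|k|)."""
    F = np.fft.fft(np.eye(n_pix), norm="ortho")   # unitary DFT matrix
    k = 2 * np.pi * np.fft.fftfreq(n_pix, d=dx)    # angular wavenumbers
    # F^H diag(f) F is real because f depends only on |k|
    return (F.conj().T @ np.diag(f(np.abs(k))) @ F).real


def wiener_filter(d, R, N, S_inv):
    """Posterior mean m = D j, D = (R^T N^-1 R + S^-1)^-1, j = R^T N^-1 d."""
    N_inv = np.linalg.inv(N)
    j = R.T @ N_inv @ d
    D_inv = R.T @ N_inv @ R + S_inv
    # Solve D^-1 m = j instead of forming D explicitly
    return np.linalg.solve(D_inv, j)


# Hypothetical 1D setup reusing the script's constants
L, n_pix = 2.0, 64
dx = L / n_pix
k_0 = 1.0 / 0.1                 # 1/correlation_length
a = 2.0**2 / k_0 * 4.0          # field_variance**2/k_0*4
response_sigma = 0.01
pow_spec = lambda k: a / (1 + k / k_0) ** 4

S = fourier_diagonal(n_pix, dx, pow_spec)
S_inv = fourier_diagonal(n_pix, dx, lambda k: 1.0 / pow_spec(k))
R = fourier_diagonal(n_pix, dx,
                     lambda k: np.exp(-0.5 * (k * response_sigma) ** 2))

rng = np.random.default_rng(43)
s = np.linalg.cholesky(S) @ rng.standard_normal(n_pix)   # s ~ N(0, S)
noiseless_data = R @ s
noise_amplitude = noiseless_data.std() / 1.0             # signal_to_noise = 1
N = noise_amplitude**2 * np.eye(n_pix)
d = noiseless_data + noise_amplitude * rng.standard_normal(n_pix)

m = wiener_filter(d, R, N, S_inv)
```

The equivalent data-space form m = S R^T (R S R^T + N)^-1 d gives the same result and makes a convenient consistency check. For the script's 256 x 256 grid such dense matrices are impractical, which is why it applies D with conjugate gradient instead.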