diff --git a/1_wiener_filter.py b/1_wiener_filter.py index fd06aa072cea39d50b2be11466fd7439576400fe..bb194af450d4c31365011f2b97bcda8e71f2cf9b 100644 --- a/1_wiener_filter.py +++ b/1_wiener_filter.py @@ -18,7 +18,7 @@ import numpy as np import nifty5 as ift -from helpers import plot_WF +from helpers import generate_wf_data, plot_WF np.random.seed(42) @@ -26,23 +26,13 @@ np.random.seed(42) position_space = ift.RGSpace(256) -# Generate data and signal -harmonic_space = position_space.get_default_codomain() -HT = ift.HartleyOperator(harmonic_space, target=position_space) -N = ift.ScalingOperator(0.1, position_space) -S_k = ift.create_power_operator(harmonic_space, lambda k: 1/(10. + k**2.5)) -s = HT(S_k.draw_sample()) -d = s + N.draw_sample() -np.save('data.npy', d.to_global_data()) -np.save('signal.npy', s.to_global_data()) -# End generate data and signal +prior_spectrum = lambda k: 1/(10. + k**2.5) +data, ground_truth = generate_wf_data(position_space, prior_spectrum) R = ift.GeometryRemover(position_space) data_space = R.target -data = np.load('data.npy') data = ift.from_global_data(data_space, data) -ground_truth = np.load('signal.npy') ground_truth = ift.from_global_data(position_space, ground_truth) plot_WF('data', ground_truth, data) @@ -51,11 +41,6 @@ N = ift.ScalingOperator(0.1, data_space) harmonic_space = position_space.get_default_codomain() HT = ift.HartleyOperator(harmonic_space, target=position_space) - -def prior_spectrum(k): - return 1/(10. + k**2.5) - - S_h = ift.create_power_operator(harmonic_space, prior_spectrum) S = HT @ S_h @ HT.adjoint diff --git a/helpers/generate_data.py b/helpers/generate_data.py index a5c1cf1a0c3bc4e8a06995b659b4b794f51e94b3..1ad8caa6f753d8a5720f62300ab420ac355e7564 100644 --- a/helpers/generate_data.py +++ b/helpers/generate_data.py @@ -38,3 +38,17 @@ def generate_bernoulli_data(signal_response): rate = signal_response(ground_truth).to_global_data() d = np.random.binomial(1, rate) return ift.from_global_data(signal_response.target, d), ground_truth + + +def generate_wf_data(domain, spectrum): + harmonic_space = domain.get_default_codomain() + HT = ift.HartleyOperator(harmonic_space, target=domain) + N = ift.ScalingOperator(0.1, domain) + S_k = ift.create_power_operator(harmonic_space, spectrum) + s = HT(S_k.draw_sample()).to_global_data() + d = (s + N.draw_sample()).to_global_data() + return d, s + + +def generate_mysterious_data(domain): + return generate_wf_data(domain, lambda k: 5/((7**2 - k**2)**2 + 3**2*k**2))