diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py index 4987437ae681cb8f35773b04326788f3a2e91574..c9c190b4213823190c5fe9c43785e09524a19a68 100644 --- a/demos/krylov_sampling.py +++ b/demos/krylov_sampling.py @@ -40,7 +40,8 @@ IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter) m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC) m_x = sky(m) inverter = ift.ConjugateGradient(IC) -curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter) +curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter, + sampling_inverter=inverter) samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)] plt.plot(d.to_global_data(), '+', label="data", alpha=.5) diff --git a/demos/paper_demos/cartesian_wiener_filter.py b/demos/paper_demos/cartesian_wiener_filter.py index ce14a604daac4edd161e9a76e31adbceaac54862..d2ac06f0f5af1374abb72fa7794b2023aec379d0 100644 --- a/demos/paper_demos/cartesian_wiener_filter.py +++ b/demos/paper_demos/cartesian_wiener_filter.py @@ -76,7 +76,7 @@ if __name__ == "__main__": ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1) inverter = ift.ConjugateGradient(controller=ctrl) wiener_curvature = ift.library.WienerFilterCurvature( - S=S, N=N, R=R, inverter=inverter) + S=S, N=N, R=R, inverter=inverter, sampling_inverter=inverter) m_k = wiener_curvature.inverse_times(j) m = ht(m_k) diff --git a/demos/paper_demos/wiener_filter.py b/demos/paper_demos/wiener_filter.py index 8fa668992e5d9570f87b3396596a89069c61c983..dca5a3ced06240fd806802b6ea16055799eb3632 100644 --- a/demos/paper_demos/wiener_filter.py +++ b/demos/paper_demos/wiener_filter.py @@ -50,7 +50,7 @@ if __name__ == "__main__": ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2) inverter = ift.ConjugateGradient(controller=ctrl) wiener_curvature = ift.library.WienerFilterCurvature( - S=S, N=N, R=R, inverter=inverter) + S=S, N=N, R=R, inverter=inverter, sampling_inverter=inverter) m_k = wiener_curvature.inverse_times(j) m = ht(m_k)