diff --git a/demos/Wiener_Filter.ipynb b/demos/Wiener_Filter.ipynb index cf92a9eaee7d68300ba5799b54f4e5c17ba20a62..a8da9671a406089e48689de1a996826de1522a84 100644 --- a/demos/Wiener_Filter.ipynb +++ b/demos/Wiener_Filter.ipynb @@ -166,13 +166,13 @@ }, "outputs": [], "source": [ - "def PropagatorOperator(R, N, Sh):\n", + "def Curvature(R, N, Sh):\n", " IC = ift.GradientNormController(iteration_limit=50000,\n", " tol_abs_gradnorm=0.1)\n", " inverter = ift.ConjugateGradient(controller=IC)\n", - " D = (R.adjoint*N.inverse*R + Sh.inverse).inverse\n", - " # MR FIXME: we can/should provide a preconditioner here as well!\n", - " return ift.InversionEnabler(D, inverter)\n" + " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", + " # helper methods.\n", + " return ift.library.WienerFilterCurvature(R,N,Sh,inverter)\n" ] }, { @@ -245,7 +245,8 @@ " std=noise_amplitude, mean=0)\n", "d = noiseless_data + n\n", "j = R.adjoint_times(N.inverse_times(d))\n", - "D = PropagatorOperator(R=R, N=N, Sh=Sh)" + "curv = Curvature(R=R, N=N, Sh=Sh)\n", + "D = curv.inverse" ] }, { @@ -468,7 +469,8 @@ }, "outputs": [], "source": [ - "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n", + "curv = Curvature(R=R, N=N, Sh=Sh)\n", + "D = curv.inverse\n", "j = R.adjoint_times(N.inverse_times(d))\n", "m = D(j)" ] @@ -493,12 +495,6 @@ "outputs": [], "source": [ "sc = ift.probing.utils.StatCalculator()\n", - "\n", - "IC = ift.GradientNormController(iteration_limit=50000,\n", - " tol_abs_gradnorm=0.1)\n", - "inverter = ift.ConjugateGradient(controller=IC)\n", - "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n", - "\n", "for i in range(200):\n", " print i\n", " sc.add(HT(curv.generate_posterior_sample()))\n", @@ -627,8 +623,6 @@ "# Operators\n", "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", "N = ift.ScalingOperator(sigma2,s_space)\n", - "R = ift.FFTSmoothingOperator(s_space, sigma=.01)\n", - "#D = PropagatorOperator(R=R, N=N, Sh=Sh)\n", "\n", "# Fields and data\n", "sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n", @@ -645,7 +639,8 @@ "\n", "R = ift.DiagonalOperator(mask)*HT\n", "n.val[l:h, l:h] = 0\n", - "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n", + "curv = Curvature(R=R, N=N, Sh=Sh)\n", + "D = curv.inverse\n", "\n", "d = R(sh) + n\n", "j = R.adjoint_times(N.inverse_times(d))\n",