From d3f0bf6c0bfeaf2259c7accf3020b799e8e6c7cc Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 5 Feb 2018 11:06:52 +0100
Subject: [PATCH] cleanup
---
demos/Wiener_Filter.ipynb | 25 ++++++++++---------------
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/demos/Wiener_Filter.ipynb b/demos/Wiener_Filter.ipynb
index cf92a9eae..a8da9671a 100644
--- a/demos/Wiener_Filter.ipynb
+++ b/demos/Wiener_Filter.ipynb
@@ -166,13 +166,13 @@
},
"outputs": [],
"source": [
- "def PropagatorOperator(R, N, Sh):\n",
+ "def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n",
" inverter = ift.ConjugateGradient(controller=IC)\n",
- " D = (R.adjoint*N.inverse*R + Sh.inverse).inverse\n",
- " # MR FIXME: we can/should provide a preconditioner here as well!\n",
- " return ift.InversionEnabler(D, inverter)\n"
+ " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
+ " # helper methods.\n",
+ " return ift.library.WienerFilterCurvature(R,N,Sh,inverter)\n"
]
},
{
@@ -245,7 +245,8 @@
" std=noise_amplitude, mean=0)\n",
"d = noiseless_data + n\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
- "D = PropagatorOperator(R=R, N=N, Sh=Sh)"
+ "curv = Curvature(R=R, N=N, Sh=Sh)\n",
+ "D = curv.inverse"
]
},
{
@@ -468,7 +469,8 @@
},
"outputs": [],
"source": [
- "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n",
+ "curv = Curvature(R=R, N=N, Sh=Sh)\n",
+ "D = curv.inverse\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
"m = D(j)"
]
@@ -493,12 +495,6 @@
"outputs": [],
"source": [
"sc = ift.probing.utils.StatCalculator()\n",
- "\n",
- "IC = ift.GradientNormController(iteration_limit=50000,\n",
- " tol_abs_gradnorm=0.1)\n",
- "inverter = ift.ConjugateGradient(controller=IC)\n",
- "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n",
- "\n",
"for i in range(200):\n",
" print i\n",
" sc.add(HT(curv.generate_posterior_sample()))\n",
@@ -627,8 +623,6 @@
"# Operators\n",
"Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n",
"N = ift.ScalingOperator(sigma2,s_space)\n",
- "R = ift.FFTSmoothingOperator(s_space, sigma=.01)\n",
- "#D = PropagatorOperator(R=R, N=N, Sh=Sh)\n",
"\n",
"# Fields and data\n",
"sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n",
@@ -645,7 +639,8 @@
"\n",
"R = ift.DiagonalOperator(mask)*HT\n",
"n.val[l:h, l:h] = 0\n",
- "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n",
+ "curv = Curvature(R=R, N=N, Sh=Sh)\n",
+ "D = curv.inverse\n",
"\n",
"d = R(sh) + n\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
--
GitLab