Commit a4c2996b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix 2D part of notebook

parent 28deed25
Pipeline #24370 passed with stage
in 6 minutes and 33 seconds
...@@ -490,12 +490,13 @@ ...@@ -490,12 +490,13 @@
"source": [ "source": [
"sc = ift.probing.utils.StatCalculator()\n", "sc = ift.probing.utils.StatCalculator()\n",
"\n", "\n",
"IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n", "IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n", " tol_abs_gradnorm=0.1)\n",
"inverter = ift.ConjugateGradient(controller=IC)\n", "inverter = ift.ConjugateGradient(controller=IC)\n",
"curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n", "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n",
"\n", "\n",
"for i in range(200):\n", "for i in range(200):\n",
" print i\n",
" sc.add(HT(curv.generate_posterior_sample()))\n", " sc.add(HT(curv.generate_posterior_sample()))\n",
"\n", "\n",
"m_var = sc.var" "m_var = sc.var"
...@@ -594,11 +595,11 @@ ...@@ -594,11 +595,11 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"N_pixels = 256 # Number of pixels\n", "N_pixels = 256 # Number of pixels\n",
"sigma2 = 1000 # Noise variance\n", "sigma2 = 10. # Noise variance\n",
"\n", "\n",
"\n", "\n",
"def pow_spec(k):\n", "def pow_spec(k):\n",
" P0, k0, gamma = [.2, 20, 4]\n", " P0, k0, gamma = [.2, 5, 4]\n",
" return P0 * (1. + (k/k0)**2)**(- gamma / 2)\n", " return P0 * (1. + (k/k0)**2)**(- gamma / 2)\n",
"\n", "\n",
"\n", "\n",
...@@ -616,7 +617,7 @@ ...@@ -616,7 +617,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"h_space = s_space.get_default_codomain()\n", "h_space = s_space.get_default_codomain()\n",
"fft = ift.FFTOperator(s_space,h_space)\n", "HT = ift.HarmonicTransformOperator(h_space,s_space)\n",
"p_space = ift.PowerSpace(h_space)\n", "p_space = ift.PowerSpace(h_space)\n",
"\n", "\n",
"# Operators\n", "# Operators\n",
...@@ -627,7 +628,6 @@ ...@@ -627,7 +628,6 @@
"\n", "\n",
"# Fields and data\n", "# Fields and data\n",
"sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n", "sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n",
"s = fft.adjoint_times(sh)\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=np.sqrt(sigma2), mean=0)\n", " std=np.sqrt(sigma2), mean=0)\n",
"\n", "\n",
...@@ -639,11 +639,11 @@ ...@@ -639,11 +639,11 @@
"mask = ift.Field(s_space, val=1)\n", "mask = ift.Field(s_space, val=1)\n",
"mask.val[l:h,l:h] = 0\n", "mask.val[l:h,l:h] = 0\n",
"\n", "\n",
"R = ift.DiagonalOperator(mask)\n", "R = ift.DiagonalOperator(mask)*HT\n",
"n.val[l:h, l:h] = 0\n", "n.val[l:h, l:h] = 0\n",
"D = PropagatorOperator(R=R, N=N, Sh=fft.inverse*Sh*fft)\n", "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n",
"\n", "\n",
"d = R(s) + n\n", "d = R(sh) + n\n",
"j = R.adjoint_times(N.inverse_times(d))\n", "j = R.adjoint_times(N.inverse_times(d))\n",
"\n", "\n",
"# Run Wiener filter\n", "# Run Wiener filter\n",
...@@ -652,28 +652,26 @@ ...@@ -652,28 +652,26 @@
"# Uncertainty\n", "# Uncertainty\n",
"sc = ift.probing.utils.StatCalculator()\n", "sc = ift.probing.utils.StatCalculator()\n",
"\n", "\n",
"IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n", "IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n", " tol_abs_gradnorm=0.1)\n",
"inverter = ift.ConjugateGradient(controller=IC)\n", "inverter = ift.ConjugateGradient(controller=IC)\n",
"curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,fft.inverse*Sh*fft,inverter)\n", "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n",
"\n", "\n",
"for i in range(20):\n", "for i in range(20):\n",
" print i\n",
" sc.add(HT(curv.generate_posterior_sample()))\n", " sc.add(HT(curv.generate_posterior_sample()))\n",
"\n", "\n",
"m_var = sc.var\n", "m_var = sc.var\n",
"diagProber = DiagonalProber(domain=s_space, probe_dtype=np.complex, probe_count=10)\n",
"diagProber(D)\n",
"m_var = Field(s_space, val=diagProber.diagonal.val).weight(-1)\n",
"\n", "\n",
"# Get data\n", "# Get data\n",
"s_power = sh.power_analyze()\n", "s_power = ift.power_analyze(sh)\n",
"m_power = fft(m).power_analyze()\n", "m_power = ift.power_analyze(m)\n",
"s_power_data = s_power.val.get_full_data().real\n", "s_power_data = s_power.val.real\n",
"m_power_data = m_power.val.get_full_data().real\n", "m_power_data = m_power.val.real\n",
"s_data = s.val.get_full_data().real\n", "s_data = HT(sh).val.real\n",
"m_data = m.val.get_full_data().real\n", "m_data = HT(m).val.real\n",
"m_var_data = m_var.val.get_full_data().real\n", "m_var_data = m_var.val.real\n",
"d_data = d.val.get_full_data().real\n", "d_data = d.val.real\n",
"\n", "\n",
"uncertainty = np.sqrt(np.abs(m_var_data))" "uncertainty = np.sqrt(np.abs(m_var_data))"
] ]
...@@ -708,15 +706,6 @@ ...@@ -708,15 +706,6 @@
"fig.colorbar(im, cax=cbar_ax)" "fig.colorbar(im, cax=cbar_ax)"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -744,19 +733,6 @@ ...@@ -744,19 +733,6 @@
"fig.colorbar(im, cax=cbar_ax)" "fig.colorbar(im, cax=cbar_ax)"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"fig"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
...@@ -783,8 +759,7 @@ ...@@ -783,8 +759,7 @@
"\n", "\n",
"fig = plt.figure()\n", "fig = plt.figure()\n",
"plt.imshow(precise.astype(float), cmap=\"brg\")\n", "plt.imshow(precise.astype(float), cmap=\"brg\")\n",
"plt.colorbar()\n", "plt.colorbar()"
"fig"
] ]
}, },
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment