diff --git a/data/mock_signal.npz b/data/mock_signal.npz new file mode 100644 index 0000000000000000000000000000000000000000..62e13b67caf73c63938a9c6e6a5bd1490f5bd947 Binary files /dev/null and b/data/mock_signal.npz differ diff --git a/demo_radio.ipynb b/demo_radio.ipynb index fc4269fdfd97cfe363d65aec342875ca574bf723..71e132948c46d792209b7d760f560eebf6b072a8 100644 --- a/demo_radio.ipynb +++ b/demo_radio.ipynb @@ -242,6 +242,26 @@ "source": [ "from IPython.display import clear_output\n", "\n", + "def _imshow(figure, field, ax, title, vmin = 0, vmax = None, cmap='afmhot'):\n", + " im0 = ax.imshow(field.val.T, origin = 'lower', extent = [0, x_fov, 0, y_fov], cmap=cmap,\n", + " vmin = vmin, vmax = vmax)\n", + " figure.colorbar(im0, ax=ax)\n", + " ax.set_xlabel(r'x $\\left(\\mu as\\right)$')\n", + " ax.set_ylabel(r'y $\\left(\\mu as\\right)$')\n", + " ax.set_title(title)\n", + "\n", + "def _plot_histogram(nodes, hist, ax, title, ):\n", + " nodes = 0.5*(nodes[1:] + nodes[:-1])\n", + " ax.bar(nodes, hist)\n", + " rs = np.arange(nodes[0], nodes[-1], 0.1)\n", + " gauss = np.exp(-0.5*rs**2)/np.sqrt(2*np.pi)\n", + " ax.plot(rs, gauss, 'k--', label = r'standard Gauss')\n", + " ax.set_xlabel(r'$r$')\n", + " ax.set_ylabel(r'$P(r)$')\n", + " ax.set_title(title)\n", + " ax.legend()\n", + " ax.set_xlim([nodes[0], nodes[-1]])\n", + "\n", "def plotting_callback(samples):\n", " clear_output(wait=True) \n", "\n", @@ -260,21 +280,15 @@ " wgt, nodes = np.histogram(rr, nbins, range=[-5, 5])\n", " hist += wgt/wgt.sum()/(nodes[1]-nodes[0])\n", " hist /= samples.n_samples\n", - " nodes = 0.5*(nodes[1:] + nodes[:-1])\n", "\n", "\n", "\n", " fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,13))\n", " axs = axs.flatten()\n", - " im0 = axs[0].imshow(sky_mean.val.T, origin = 'lower', extent = [0, x_fov, 0, y_fov], cmap='afmhot')\n", - " fig.colorbar(im0, ax=axs[0])\n", - " axs[0].set_xlabel(r'x $\\left(\\mu as\\right)$')\n", - " axs[0].set_ylabel(r'y $\\left(\\mu as\\right)$')\n", - " axs[0].set_title('Sky brightness mean')\n", - " im1 = axs[1].imshow(sky_var.sqrt().val.T, origin = 'lower', extent = [0, x_fov, 0, y_fov])\n", - " fig.colorbar(im1, ax=axs[1])\n", - " axs[1].set_xlabel(r'x $\\left(\\mu as\\right)$')\n", - " axs[1].set_title('Sky brightness std')\n", + "\n", + " _imshow(fig, sky_mean, axs[0], 'Sky brightness mean', vmax = 2e19)\n", + " _imshow(fig, sky_var.sqrt(), axs[1], 'Sky brightness std', cmap = 'viridis')\n", + "\n", " axs[1].yaxis.set_visible(False)\n", " k_lengths = pspec_mean.domain[0].k_lengths[1:]\n", " lbl = 'samples'\n", @@ -290,14 +304,9 @@ " axs[2].set_title(r'Power-spectrum of log-sky brightness')\n", " axs[2].legend()\n", "\n", - " axs[3].bar(nodes, hist)\n", - " rs = np.arange(-5, 5, 0.1)\n", - " gauss = np.exp(-0.5*rs**2)/np.sqrt(2*np.pi)\n", - " axs[3].plot(rs, gauss, 'k--', label = r'standard Gauss')\n", - " axs[3].set_xlabel(r'$r$')\n", - " axs[3].set_ylabel(r'$P(r)$')\n", - " axs[3].set_title(r'Inverse noise weighted data residual ($r$) distribution ($P(r)$)')\n", - " axs[3].legend()\n", + " _plot_histogram(nodes, hist, axs[3], \n", + " r'Inverse noise weighted data residual ($r$) distribution ($P(r)$)')\n", + "\n", " fig.tight_layout()\n", " plt.show();" ] @@ -342,6 +351,7 @@ "# generation and optimization.\n", "n_iterations = 15 # Total number of iterations. \n", "n_samples = (lambda iiter: 2 if iiter < 10 else 5) # Number of samples used for KL approximation\n", + "\n", "minimizer = (lambda iiter: minimizer_early if iiter < 10 else minimizer_late) # When to use which optimizer\n", "minimizer_sampling = (lambda iiter: None if iiter < 10 else ift.NewtonCG(ic_sampling_nl)) # initially MGVI, \n", " # later geoVI\n", @@ -353,6 +363,36 @@ " output_directory=\"mock_inference\", overwrite=True)\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Comparison of posterior to ground truth\n", + "=======================================" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mock_sky = np.load('data/mock_signal.npz')['signal']\n", + "mock_sky = ift.makeField(space, mock_sky)\n", + "\n", + "sky_mean, sky_var = samples.sample_stat(sky_model)\n", + "sky_samples = list(s for s in samples.iterator(sky_model))\n", + "\n", + "fig, axs = plt.subplots(nrows=3, ncols=2, figsize = (15,18))\n", + "_imshow(fig, mock_sky, axs[0,0], 'Sky brightness ground truth', vmax = 2e19)\n", + "_imshow(fig, sky_mean, axs[0,1], 'Sky brightness mean', vmax = 2e19)\n", + "_imshow(fig, sky_var.sqrt()/sky_mean, axs[1,0], 'Sky brightness relative uncertainty')\n", + "_imshow(fig, sky_samples[0], axs[1,1], 'Sky brightness posterior sample (1)', vmax = 2e19)\n", + "_imshow(fig, sky_samples[1], axs[2,0], 'Sky brightness posterior sample (2)', vmax = 2e19)\n", + "_imshow(fig, sky_samples[2], axs[2,1], 'Sky brightness posterior sample (3)', vmax = 2e19)\n", + "fig.tight_layout()\n" + ] + }, { "cell_type": "code", "execution_count": null,