{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Code example: Wiener filter" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Introduction\n", "IFT starting point:\n", "\n", "$$d = Rs+n$$\n", "\n", "Typically, $s$ is a continuous field and $d$ a discrete data vector. In particular, $R$ is not invertible.\n", "\n", "IFT aims at **inverting** the above non-invertible problem in the **best possible way** using Bayesian statistics.\n", "\n", "NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.\n", "\n", "Main interfaces:\n", "\n", "- **Spaces**: Cartesian, 2-spheres (HEALPix, Gauss-Legendre) and their respective harmonic spaces.\n", "- **Fields**: Defined on spaces.\n", "- **Operators**: Acting on fields." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Wiener filter on one-dimensional fields\n", "\n", "### Assumptions\n", "\n", "- $d=Rs+n$, where $R$ is a linear operator.\n", "- $\\mathcal P (s) = \\mathcal G (s,S)$, $\\mathcal P (n) = \\mathcal G (n,N)$ where $S, N$ are positive definite matrices.\n", "\n", "### Posterior\n", "The posterior is given by:\n", "\n", "$$\\mathcal P (s|d) \\propto \\mathcal P (s,d) = \\mathcal G(d-Rs,N) \\,\\mathcal G(s,S) \\propto \\mathcal G (s-m,D) $$\n", "\n", "where\n", "$$m = Dj$$\n", "with\n", "$$D = (S^{-1} +R^\\dagger N^{-1} R)^{-1} , \\quad j = R^\\dagger N^{-1} d.$$\n", "\n", "Let us implement this in NIFTy!" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### In NIFTy\n", "\n", "- We assume statistical homogeneity and isotropy. 
Therefore, the signal covariance $S$ is diagonal in harmonic space and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\\,\\left(1+\\left(\\frac{k}{k_0}\\right)^2\\right)^{-\\gamma /2},$$\n", "with $P_0 = 0.2, k_0 = 5, \\gamma = 4$.\n", "- Noise covariance $N = 0.2 \\cdot \\mathbb{1}$.\n", "- Number of data points $N_{\\text{pix}} = 512$.\n", "- Reconstruction in harmonic space.\n", "- Response operator:\n", "$$R = \\mathrm{FFT}_{\\text{harmonic} \\rightarrow \\text{position}}$$\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "N_pixels = 512 # Number of pixels\n", "\n", "def pow_spec(k):\n", "    P0, k0, gamma = [.2, 5, 4]\n", "    return P0 / ((1. + (k/k0)**2)**(gamma / 2))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Implementation" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "#### Import Modules" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import nifty8 as ift\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.dpi'] = 100\n", "plt.style.use(\"seaborn-notebook\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Implement Propagator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "def Curvature(R, N, Sh):\n", "    IC = ift.GradientNormController(iteration_limit=50000,\n", "                                    tol_abs_gradnorm=0.1)\n", "    N = ift.SamplingDtypeSetter(N, np.float64)\n", "    Sh = ift.SamplingDtypeSetter(Sh, np.float64)\n", "    # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", "    # helper methods.\n", "    return 
ift.WienerFilterCurvature(R, N, Sh, iteration_controller=IC, iteration_controller_sampling=IC)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "#### Conjugate Gradient Preconditioning\n", "\n", "- $D$ is defined via:\n", "$$D^{-1} = \\mathcal S_h^{-1} + R^\\dagger N^{-1} R.$$\n", "In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically with the *conjugate gradient* algorithm." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Generate Mock Data\n", "\n", "- Generate fields $s$ and $n$ with the given covariances.\n", "- Calculate $d$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "s_space = ift.RGSpace(N_pixels)\n", "h_space = s_space.get_default_codomain()\n", "HT = ift.HarmonicTransformOperator(h_space, target=s_space)\n", "\n", "# Operators\n", "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", "R = HT # @ ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n", "\n", "# Fields and data\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n", "noiseless_data = R(sh)\n", "noise_amplitude = np.sqrt(0.2)\n", "N = ift.ScalingOperator(s_space, noise_amplitude**2)\n", "\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n", "                          std=noise_amplitude, mean=0)\n", "d = noiseless_data + n\n", "j = R.adjoint_times(N.inverse_times(d))\n", "curv = Curvature(R=R, N=N, Sh=Sh)\n", "D = curv.inverse" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Run Wiener Filter" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "m = D(j)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Results" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "# Get signal data and reconstruction data\n", "s_data = HT(sh).val\n", "m_data = HT(m).val\n", "d_data = d.val\n", "\n", "plt.plot(s_data, 'r', label=\"Signal\", linewidth=2)\n", "plt.plot(d_data, 'k.', label=\"Data\")\n", "plt.plot(m_data, 'k', label=\"Reconstruction\",linewidth=2)\n", "plt.title(\"Reconstruction\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "plt.plot(s_data - s_data, 'r', label=\"Signal\", linewidth=2)\n", "plt.plot(d_data - s_data, 'k.', label=\"Data\")\n", "plt.plot(m_data - s_data, 'k', label=\"Reconstruction\",linewidth=2)\n", "plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)\n", "plt.title(\"Residuals\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Power Spectrum" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "s_power_data = ift.power_analyze(sh).val\n", "m_power_data = ift.power_analyze(m).val\n", "plt.loglog()\n", "plt.xlim(1, int(N_pixels/2))\n", "ymin = min(m_power_data)\n", "plt.ylim(ymin, 1)\n", "xs = np.arange(1,int(N_pixels/2),.1)\n", "plt.plot(xs, pow_spec(xs), label=\"True Power Spectrum\", color='k',alpha=0.5)\n", "plt.plot(s_power_data, 'r', label=\"Signal\")\n", "plt.plot(m_power_data, 'k', label=\"Reconstruction\")\n", "plt.axhline(noise_amplitude**2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n", "plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n", "plt.title(\"Power Spectrum\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Wiener Filter on 
Incomplete Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# Operators\n", "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", "N = ift.ScalingOperator(s_space, noise_amplitude**2)\n", "# R is defined below\n", "\n", "# Fields\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n", "s = HT(sh)\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n", " std=noise_amplitude, mean=0)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "### Partially Lose Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "l = int(N_pixels * 0.2)\n", "h = int(N_pixels * 0.2 * 2)\n", "\n", "mask = np.full(s_space.shape, 1.)\n", "mask[l:h] = 0\n", "mask = ift.Field.from_raw(s_space, mask)\n", "\n", "R = ift.DiagonalOperator(mask)(HT)\n", "n = n.val_rw()\n", "n[l:h] = 0\n", "n = ift.Field.from_raw(s_space, n)\n", "\n", "d = R(sh) + n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "curv = Curvature(R=R, N=N, Sh=Sh)\n", "D = curv.inverse\n", "j = R.adjoint_times(N.inverse_times(d))\n", "m = D(j)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Compute Uncertainty\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "### Get data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# Get signal data and reconstruction data\n", "s_data = s.val\n", "m_data = 
HT(m).val\n", "m_var_data = m_var.val\n", "uncertainty = np.sqrt(m_var_data)\n", "d_data = d.val_rw()\n", "\n", "# Set lost data to NaN for proper plotting\n", "d_data[d_data == 0] = np.nan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "plt.axvspan(l, h, facecolor='0.8', alpha=0.5)\n", "plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)\n", "plt.plot(s_data, 'r', label=\"Signal\", alpha=1, linewidth=2)\n", "plt.plot(d_data, 'k.', label=\"Data\")\n", "plt.plot(m_data, 'k', label=\"Reconstruction\", linewidth=2)\n", "plt.title(\"Reconstruction of incomplete data\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Wiener filter on two-dimensional fields" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "N_pixels = 256 # Number of pixels\n", "sigma2 = 2. # Noise variance\n", "\n", "def pow_spec(k):\n", "    P0, k0, gamma = [.2, 2, 4]\n", "    return P0 * (1. 
+ (k/k0)**2)**(-gamma/2)\n", "\n", "s_space = ift.RGSpace([N_pixels, N_pixels])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "h_space = s_space.get_default_codomain()\n", "HT = ift.HarmonicTransformOperator(h_space,s_space)\n", "\n", "# Operators\n", "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", "N = ift.ScalingOperator(s_space, sigma2)\n", "\n", "# Fields and data\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n", " std=np.sqrt(sigma2), mean=0)\n", "\n", "# Lose some data\n", "\n", "l = int(N_pixels * 0.33)\n", "h = int(N_pixels * 0.33 * 2)\n", "\n", "mask = np.full(s_space.shape, 1.)\n", "mask[l:h,l:h] = 0.\n", "mask = ift.Field.from_raw(s_space, mask)\n", "\n", "R = ift.DiagonalOperator(mask)(HT)\n", "n = n.val_rw()\n", "n[l:h, l:h] = 0\n", "n = ift.Field.from_raw(s_space, n)\n", "curv = Curvature(R=R, N=N, Sh=Sh)\n", "D = curv.inverse\n", "\n", "d = R(sh) + n\n", "j = R.adjoint_times(N.inverse_times(d))\n", "\n", "# Run Wiener filter\n", "m = D(j)\n", "\n", "# Uncertainty\n", "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)\n", "\n", "# Get data\n", "s_data = HT(sh).val\n", "m_data = HT(m).val\n", "m_var_data = m_var.val\n", "d_data = d.val\n", "uncertainty = np.sqrt(np.abs(m_var_data))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "cmap = ['magma', 'inferno', 'plasma', 'viridis'][1]\n", "\n", "mi = np.min(s_data)\n", "ma = np.max(s_data)\n", "\n", "fig, axes = plt.subplots(1, 2)\n", "\n", "data = [s_data, d_data]\n", "caption = [\"Signal\", \"Data\"]\n", "\n", "for ax in axes.flat:\n", " im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cmap, vmin=mi,\n", " vmax=ma)\n", " ax.set_title(caption.pop(0))\n", "\n", "fig.subplots_adjust(right=0.8)\n", 
"cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n", "fig.colorbar(im, cax=cbar_ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "mi = np.min(s_data)\n", "ma = np.max(s_data)\n", "\n", "fig, axes = plt.subplots(3, 2, figsize=(10, 15))\n", "sample = HT(curv.draw_sample(from_inverse=True)+m).val\n", "post_mean = (m_mean + HT(m)).val\n", "\n", "data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]\n", "caption = [\"Signal\", \"Reconstruction\", \"Posterior mean\", \"Sample\", \"Residuals\", \"Uncertainty Map\"]\n", "\n", "for ax in axes.flat:\n", " im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cmap, vmin=mi, vmax=ma)\n", " ax.set_title(caption.pop(0))\n", "\n", "fig.subplots_adjust(right=0.8)\n", "cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])\n", "fig.colorbar(im, cax=cbar_ax)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Is the uncertainty map reliable?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "precise = (np.abs(s_data-m_data) < uncertainty)\n", "print(\"Error within uncertainty map bounds: \" + str(np.sum(precise) * 100 / N_pixels**2) + \"%\")\n", "\n", "plt.imshow(precise.astype(float), cmap=\"brg\")\n", "plt.colorbar()" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 2 }