Commit e8160f7e authored by Philipp Arras's avatar Philipp Arras
Browse files

Migrate from ipynb to scripts+jupytext

parent 16066ede
Pipeline #123189 passed with stages
in 18 minutes and 30 seconds
......@@ -78,11 +78,13 @@ before_script:
run_ipynb0:
stage: demo_runs
script:
- jupytext --to ipynb demos/getting_started_0.py
- jupyter nbconvert --execute --ExecutePreprocessor.timeout=None --to html demos/getting_started_0.ipynb
run_ipynb1:
stage: demo_runs
script:
- jupytext --to ipynb demos/getting_started_4_CorrelatedFields.py
- jupyter nbconvert --execute --ExecutePreprocessor.timeout=None --to html demos/getting_started_4_CorrelatedFields.ipynb
run_getting_started_1:
......
......@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib python3-h5py python3-astropy \
# more optional NIFTy dependencies
&& DUCC0_OPTIMIZATION=portable pip3 install ducc0 finufft jupyter jax jaxlib sphinx pydata-sphinx-theme \
&& DUCC0_OPTIMIZATION=portable pip3 install ducc0 finufft jupyter jax jaxlib sphinx pydata-sphinx-theme jupytext \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Code example: Wiener filter"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"## Introduction\n",
"IFT starting point:\n",
"\n",
"$$d = Rs+n$$\n",
"\n",
"Typically, $s$ is a continuous field, $d$ a discrete data vector. Particularly, $R$ is not invertible.\n",
"\n",
"IFT aims at **inverting** the above uninvertible problem in the **best possible way** using Bayesian statistics.\n",
"\n",
"NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.\n",
"\n",
"Main Interfaces:\n",
"\n",
"- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.\n",
"- **Fields**: Defined on spaces.\n",
"- **Operators**: Acting on fields."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"## Wiener filter on one-dimensional fields\n",
"\n",
"### Assumptions\n",
"\n",
"- $d=Rs+n$, $R$ linear operator.\n",
"- $\\mathcal P (s) = \\mathcal G (s,S)$, $\\mathcal P (n) = \\mathcal G (n,N)$ where $S, N$ are positive definite matrices.\n",
"\n",
"### Posterior\n",
"The Posterior is given by:\n",
"\n",
"$$\\mathcal P (s|d) \\propto P(s,d) = \\mathcal G(d-Rs,N) \\,\\mathcal G(s,S) \\propto \\mathcal G (s-m,D) $$\n",
"\n",
"where\n",
"$$m = Dj$$\n",
"with\n",
"$$D = (S^{-1} +R^\\dagger N^{-1} R)^{-1} , \\quad j = R^\\dagger N^{-1} d.$$\n",
"\n",
"Let us implement this in NIFTy!"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### In NIFTy\n",
"\n",
"- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\\,\\left(1+\\left(\\frac{k}{k_0}\\right)^2\\right)^{-\\gamma /2},$$\n",
"with $P_0 = 0.2, k_0 = 5, \\gamma = 4$.\n",
"- $N = 0.2 \\cdot \\mathbb{1}$.\n",
"- Number of data points $N_{pix} = 512$.\n",
"- reconstruction in harmonic space.\n",
"- Response operator:\n",
"$$R = FFT_{\\text{harmonic} \\rightarrow \\text{position}}$$\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"N_pixels = 512 # Number of pixels\n",
"\n",
"def pow_spec(k):\n",
" P0, k0, gamma = [.2, 5, 4]\n",
" return P0 / ((1. + (k/k0)**2)**(gamma / 2))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### Implementation"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"#### Import Modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import nifty8 as ift\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['figure.dpi'] = 100\n",
"plt.style.use(\"seaborn-notebook\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Implement Propagator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n",
" return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"#### Conjugate Gradient Preconditioning\n",
"\n",
"- $D$ is defined via:\n",
"$$D^{-1} = \\mathcal S_h^{-1} + R^\\dagger N^{-1} R.$$\n",
"In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*). \n",
"\n",
"<!--\n",
"- One can define the *condition number* of a non-singular and normal matrix $A$:\n",
"$$\\kappa (A) := \\frac{|\\lambda_{\\text{max}}|}{|\\lambda_{\\text{min}}|},$$\n",
"where $\\lambda_{\\text{max}}$ and $\\lambda_{\\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.\n",
"\n",
"- The larger $\\kappa$ the slower Conjugate Gradient.\n",
"\n",
"- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:\n",
"$$\\tilde A m = \\tilde j,$$\n",
"where $\\tilde A = T D^{-1}$ and $\\tilde j = Tj$.\n",
"\n",
"- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose\n",
"\n",
"$$T = \\mathcal F^\\dagger S_h^{-1} \\mathcal F.$$\n",
"-->"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Generate Mock data\n",
"\n",
"- Generate a field $s$ and $n$ with given covariances.\n",
"- Calculate $d$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"s_space = ift.RGSpace(N_pixels)\n",
"h_space = s_space.get_default_codomain()\n",
"HT = ift.HarmonicTransformOperator(h_space, target=s_space)\n",
"\n",
"# Operators\n",
"Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec, sampling_dtype=float)\n",
"R = HT # @ ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n",
"\n",
"# Fields and data\n",
"sh = Sh.draw_sample()\n",
"noiseless_data=R(sh)\n",
"noise_amplitude = np.sqrt(0.2)\n",
"N = ift.ScalingOperator(s_space, noise_amplitude**2, float)\n",
"\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=noise_amplitude, mean=0)\n",
"d = noiseless_data + n\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
"curv = Curvature(R=R, N=N, Sh=Sh)\n",
"D = curv.inverse"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Run Wiener Filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"m = D(j)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"# Get signal data and reconstruction data\n",
"s_data = HT(sh).val\n",
"m_data = HT(m).val\n",
"d_data = d.val\n",
"\n",
"plt.plot(s_data, 'r', label=\"Signal\", linewidth=2)\n",
"plt.plot(d_data, 'k.', label=\"Data\")\n",
"plt.plot(m_data, 'k', label=\"Reconstruction\",linewidth=2)\n",
"plt.title(\"Reconstruction\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"plt.plot(s_data - s_data, 'r', label=\"Signal\", linewidth=2)\n",
"plt.plot(d_data - s_data, 'k.', label=\"Data\")\n",
"plt.plot(m_data - s_data, 'k', label=\"Reconstruction\",linewidth=2)\n",
"plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)\n",
"plt.title(\"Residuals\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Power Spectrum"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"s_power_data = ift.power_analyze(sh).val\n",
"m_power_data = ift.power_analyze(m).val\n",
"plt.loglog()\n",
"plt.xlim(1, int(N_pixels/2))\n",
"ymin = min(m_power_data)\n",
"plt.ylim(ymin, 1)\n",
"xs = np.arange(1,int(N_pixels/2),.1)\n",
"plt.plot(xs, pow_spec(xs), label=\"True Power Spectrum\", color='k',alpha=0.5)\n",
"plt.plot(s_power_data, 'r', label=\"Signal\")\n",
"plt.plot(m_power_data, 'k', label=\"Reconstruction\")\n",
"plt.axhline(noise_amplitude**2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n",
"plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n",
"plt.title(\"Power Spectrum\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Wiener Filter on Incomplete Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Operators\n",
"Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec, sampling_dtype=float)\n",
"N = ift.ScalingOperator(s_space, noise_amplitude**2, sampling_dtype=float)\n",
"# R is defined below\n",
"\n",
"# Fields\n",
"sh = Sh.draw_sample()\n",
"s = HT(sh)\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=noise_amplitude, mean=0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"### Partially Lose Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"l = int(N_pixels * 0.2)\n",
"h = int(N_pixels * 0.2 * 2)\n",
"\n",
"mask = np.full(s_space.shape, 1.)\n",
"mask[l:h] = 0\n",
"mask = ift.Field.from_raw(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask) @ HT\n",
"n = n.val_rw()\n",
"n[l:h] = 0\n",
"n = ift.Field.from_raw(s_space, n)\n",
"\n",
"d = R(sh) + n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"curv = Curvature(R=R, N=N, Sh=Sh)\n",
"D = curv.inverse\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
"m = D(j)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Compute Uncertainty\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"### Get data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Get signal data and reconstruction data\n",
"s_data = s.val\n",
"m_data = HT(m).val\n",
"m_var_data = m_var.val\n",
"uncertainty = np.sqrt(m_var_data)\n",
"d_data = d.val_rw()\n",
"\n",
"# Set lost data to NaN for proper plotting\n",
"d_data[d_data == 0] = np.nan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"plt.axvspan(l, h, facecolor='0.8',alpha=0.5)\n",
"plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)\n",
"plt.plot(s_data, 'r', label=\"Signal\", alpha=1, linewidth=2)\n",
"plt.plot(d_data, 'k.', label=\"Data\")\n",
"plt.plot(m_data, 'k', label=\"Reconstruction\", linewidth=2)\n",
"plt.title(\"Reconstruction of incomplete data\")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Wiener filter on two-dimensional field"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N_pixels = 256 # Number of pixels\n",
"sigma2 = 2. # Noise variance\n",
"\n",
"def pow_spec(k):\n",
" P0, k0, gamma = [.2, 2, 4]\n",
" return P0 * (1. + (k/k0)**2)**(-gamma/2)\n",
"\n",
"s_space = ift.RGSpace([N_pixels, N_pixels])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"h_space = s_space.get_default_codomain()\n",
"HT = ift.HarmonicTransformOperator(h_space,s_space)\n",
"\n",
"# Operators\n",
"Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec, sampling_dtype=float)\n",
"N = ift.ScalingOperator(s_space, sigma2, sampling_dtype=float)\n",
"\n",
"# Fields and data\n",
"sh = Sh.draw_sample()\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=np.sqrt(sigma2), mean=0)\n",
"\n",
"# Lose some data\n",
"\n",
"l = int(N_pixels * 0.33)\n",
"h = int(N_pixels * 0.33 * 2)\n",
"\n",
"mask = np.full(s_space.shape, 1.)\n",
"mask[l:h,l:h] = 0.\n",
"mask = ift.Field.from_raw(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.val_rw()\n",
"n[l:h, l:h] = 0\n",
"n = ift.Field.from_raw(s_space, n)\n",
"curv = Curvature(R=R, N=N, Sh=Sh)\n",
"D = curv.inverse\n",
"\n",
"d = R(sh) + n\n",
"j = R.adjoint_times(N.inverse_times(d))\n",
"\n",
"# Run Wiener filter\n",
"m = D(j)\n",
"\n",
"# Uncertainty\n",
"m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)\n",
"\n",
"# Get data\n",
"s_data = HT(sh).val\n",
"m_data = HT(m).val\n",
"m_var_data = m_var.val\n",
"d_data = d.val\n",
"uncertainty = np.sqrt(np.abs(m_var_data))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"cmap = ['magma', 'inferno', 'plasma', 'viridis'][1]\n",
"\n",
"mi = np.min(s_data)\n",
"ma = np.max(s_data)\n",
"\n",
"fig, axes = plt.subplots(1, 2)\n",
"\n",
"data = [s_data, d_data]\n",
"caption = [\"Signal\", \"Data\"]\n",
"\n",
"for ax in axes.flat:\n",
" im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cmap, vmin=mi,\n",
" vmax=ma)\n",
" ax.set_title(caption.pop(0))\n",