diff --git a/Dockerfile b/Dockerfile index 7f6c11883187228404da766e516fe35c264c391c..fa289e73b84c2287320371c0a5184a433cd3e367 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ RUN pip3 install --break-system-packages \ # Testing dependencies pytest pytest-cov pytest-xdist \ # Documentation build dependencies - jupyter nbconvert jupytext sphinx pydata-sphinx-theme myst-parser + jupyter nbconvert jupytext sphinx pydata-sphinx-theme myst-parser sphinxcontrib-bibtex # Create user (openmpi does not like to be run as root) RUN useradd -ms /bin/bash runner diff --git a/README.md b/README.md index 53d7ed17eb8de065fd74b9830730260dcef0e27e..3bb2eb9bed9f25a7e4eacbf8c2921b7efd843faf 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,7 @@ To build the documentation locally, run: ``` sudo apt-get install dvipng jupyter-nbconvert texlive-latex-base texlive-latex-extra -pip install --user sphinx jupytext pydata-sphinx-theme myst-parser +pip install --user sphinx jupytext pydata-sphinx-theme myst-parser sphinxcontrib-bibtex cd <nifty_directory> bash docs/generate.sh ``` diff --git a/docs/generate.sh b/docs/generate.sh index dca15072d1f36d9c83dcdf7c67252ef1e99fe234..0f388263488cf9c416c0e100419a32a62fc0b4fb 100755 --- a/docs/generate.sh +++ b/docs/generate.sh @@ -2,9 +2,9 @@ set -e -FOLDER=docs/source/user/ +FOLDER=docs/source/user -for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/old_nifty_getting_started_4_CorrelatedFields ${FOLDER}/old_nifty_custom_nonlinearities; do +for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/old_nifty_getting_started_4_CorrelatedFields ${FOLDER}/old_nifty_custom_nonlinearities ${FOLDER}/niftyre_getting_started_4_CorrelatedFields; do if [ ! -f "${FILE}.md" ] || [ "${FILE}.ipynb" -nt "${FILE}.md" ]; then jupytext --to ipynb "${FILE}.py" jupyter-nbconvert --to markdown --execute --ExecutePreprocessor.timeout=None "${FILE}.ipynb" @@ -12,5 +12,5 @@ for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/ol done EXCLUDE="nifty8/logger.py" -sphinx-apidoc -e -o docs/source/mod nifty8 ${EXCLUDE} +sphinx-apidoc -e -d 1 -o docs/source/mod nifty8 ${EXCLUDE} sphinx-build -b html docs/source/ docs/build/ diff --git a/docs/source/conf.py b/docs/source/conf.py index 67ecab25e064e798423dddb5cdd18a251c9db61c..8755fa888ff40113c617a2b46dd139f0917c85fd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -4,11 +4,14 @@ needs_sphinx = '3.2.0' extensions = [ 'sphinx.ext.napoleon', # Support for NumPy and Google style docstrings - 'sphinx.ext.imgmath', # Render math as images + 'sphinx.ext.mathjax', # Render math as images 'sphinx.ext.viewcode', # Add links to highlighted source code 'sphinx.ext.intersphinx', # Links to other sphinx docs (mostly numpy) 'myst_parser', # Parse markdown -] + 'sphinxcontrib.bibtex', +] + +bibtex_bibfiles = ['user/paper.bib'] master_doc = 'index' myst_enable_extensions = [ @@ -35,6 +38,7 @@ napoleon_use_admonition_for_references = True napoleon_include_special_with_doc = True imgmath_embed = True +numfig = True project = u'NIFTy8' copyright = u'2013-2022, Max-Planck-Society' diff --git a/docs/source/user/citations.rst b/docs/source/user/citations.rst index e76ac550658f23f70baa9d88db628c99f252b368..616785c9b5c04ea0096b2e0c97837ab01a50ff7e 100644 --- a/docs/source/user/citations.rst +++ b/docs/source/user/citations.rst @@ -1,5 +1,5 @@ -Citing NIFTy and the Magic Inside -================================= +Citing NIFTy +============ To cite the software library NIFTy, please use the following references: diff --git a/docs/source/user/index.rst b/docs/source/user/index.rst index f2fd41129a2bf0370ad8d81b50bf5baa80524bc7..9f820a1561ecf2acbb4f9a49b569e8d4fa72ce88 100644 --- a/docs/source/user/index.rst +++ b/docs/source/user/index.rst @@ -2,23 +2,28 @@ NIFTy user guide ================ -This guide is an overview and explains the main idea behind NIFTy. +This guide is an overview and explains the main conceptual idea behind NIFTy (Numerical Information Field Theory). More details on the API can be found at the `API reference <../mod/nifty8.html>`_. In the following, an online version of the latest publication describing NIFTy is provided. -The publication gives a quick overview of the core concepts powering NIFTy and showcases a simple example: +The publication below gives a quick overview of the core concepts powering NIFTy and showcases a simple example. .. toctree:: :maxdepth: 1 - + paper -The demonstration scripts which are part of NIFTy usually provide a good foundation for new projects in NIFTy: - .. toctree:: :maxdepth: 1 0_intro +The following page provides a demonstration of the Correlated field model in NIFTy.re: + +.. toctree:: + :maxdepth: 1 + + niftyre_getting_started_4_CorrelatedFields + In-depth discussion of important concepts in NIFTy and details on how to cite NIFTy is provided on the following pages: .. toctree:: @@ -28,15 +33,11 @@ In-depth discussion of important concepts in NIFTy and details on how to cite NI approximate_inference citations -In the following are guides to the old NIFTy. +In the following are guides to the NumPy-based NIFTy. They serve as reference but are likely not relevant anymore for new projects. .. toctree:: :maxdepth: 1 old_nifty - old_nifty_volume - old_nifty_design_principles - old_nifty_custom_nonlinearities - old_nifty_getting_started_0 - old_nifty_getting_started_4_CorrelatedFields + diff --git a/docs/source/user/niftyre_getting_started_4_CorrelatedFields.py b/docs/source/user/niftyre_getting_started_4_CorrelatedFields.py new file mode 100644 index 0000000000000000000000000000000000000000..9a459a99ffa7b3b202fc88a7af52cb5ce008a40a --- /dev/null +++ b/docs/source/user/niftyre_getting_started_4_CorrelatedFields.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python +# coding: utf-8 +# # Showcasing the Correlated Field Model +# +# The field model works roughly like this: +# +# ``f = HT( A * zero_mode * xi ) + offset`` +# +# The correlated field is constructed using: +# +# \begin{equation*} +# cf = \verb|offset_mean| + \left(\bigotimes_{i} \frac{1}{V_i} HT_i \right) \left( \verb|zero_mode| \cdot (\bigotimes_{i} A_i (k)) +# \right) +# \end{equation*} +# +# where the outer product $\bigotimes_{i}$ is taken over all subdomains, $V_i$ is the volume of each sub-space, $HT_i$ is the harmonic transform over each sub-space +# +# `A` is a spectral power field which is constructed from power spectra that are defined on subdomains of the target domain. It is scaled by a zero mode operator and then pointwise multiplied by a Gaussian excitation field, yielding a representation of the field in harmonic space. It is then transformed into the target real space and a offset is added. +# +# The power spectra that `A` is constructed of, are in turn constructed as the sum of a power law component and an integrated Wiener process whose amplitude and roughness can be set. + +# ## Preliminaries + +# + +# Imports and Initializing dimensions and a seed + +# %matplotlib inline +import jax +import nifty8.re as jft +import matplotlib.pyplot as plt +from jax import numpy as jnp +from typing import Tuple +import numpy as np +import nifty8 as ift + +jax.config.update("jax_enable_x64", True) + +plt.rcParams["figure.dpi"] = 300 + +npix = 256 +seed = 42 +distances = 1 +key = jax.random.PRNGKey(seed) +k_lengths = jnp.arange(0, npix) * (1 / npix) +totvol = jnp.prod(jnp.array(npix) * jnp.array(distances)) +realisations = 5 + + +# ## The Moment Matched Log-Normal Distribution +# +# Many properties of the correlated field are modelled as being lognormally distributed. +# +# The distribution models are parametrized via their means and standard-deviations (first and second position in tuple). +# +# To get a feeling of how the ratio of the `mean` and `stddev` parameters influences the distribution shape, here are a few example histograms: (observe the x-axis!) + +# + +fig = plt.figure(figsize=(13, 3.5)) +mean = 1.0 +sigmas = [1.0, 0.5, 0.1] + + +for i in range(3): + op = jft.LogNormalPrior(mean=mean, std=sigmas[i], name="foo") + op_samples = np.zeros(10000) + for j in range(10000): + key, signalk = jax.random.split(key, num=2) + s = jft.random_like(signalk, op.domain) + op_samples[j] = op(s) + + ax = fig.add_subplot(1, 3, i + 1) + ax.hist(op_samples, bins=50) + ax.set_title(f"mean = {mean}, sigma = {sigmas[i]}") + ax.set_xlabel("x") + del op_samples + + +plt.show() + +# + +# Making the Correlated Field in Nifty.re + + +def fieldmaker(npix: Tuple, distances: Tuple, matern, **args): + cf = jft.CorrelatedFieldMaker("") + cf.set_amplitude_total_offset( + offset_mean=args["offset_mean"], offset_std=args["offset_std"] + ) + args.pop("offset_mean") + args.pop("offset_std") + # There are two choices to the kwarg non_parametric_kind, power and amplitude. NIFTy.re's default is amplitude. + cf.add_fluctuations_matern( + npix, + distances, + non_parametric_kind="power", + renormalize_amplitude=False, + **args, + ) if matern else cf.add_fluctuations( + npix, distances, non_parametric_kind="power", **args + ) + cf_model = cf.finalize() + + return cf_model, cf.power_spectrum + + +def vary_parameter(parameter, values, matern, **args): + global key + for i, j in enumerate(values): + syn_data = np.zeros(shape=(npix, realisations)) + syn_pow = np.zeros(shape=(int(npix / 2 + 1), realisations)) + args[parameter] = j + fig = plt.figure(tight_layout=True, figsize=(10, 3)) + fig.suptitle(f"{parameter} = {j}") + ax1 = fig.add_subplot(1, 2, 1) + ax1.set_title("Field Realizations") + ax1.set_ylim( + -4.0, + 4, + ) + ax2 = fig.add_subplot(1, 2, 2) + ax2.set_xscale("log") + ax2.set_yscale("log") + ax2.set_title("Power Spectra") + # Plotting different realisations for each field. + for k in range(realisations): + cf_model, pow_cf = fieldmaker(npix, distances, matern, **args) + key, signalk = jax.random.split(key, num=2) + syn_signal = jft.random_like(signalk, cf_model.domain) + syn_data[:, k] = cf_model(syn_signal) + syn_pow[:, k] = pow_cf(syn_signal) + ax1.plot(k_lengths, syn_data[:, k], linewidth=1) + if not matern: + ax2.plot( + np.arange(len(k_lengths) / 2 + 1), + np.sqrt(syn_pow[:, k] / totvol ** 2), + linewidth=1, + ) + ax2.set_ylim(1e-6, 2.0) + else: + ax2.plot(np.arange(len(k_lengths) / 2 + 1), syn_pow[:, k], linewidth=1) + ax2.set_ylim(1e-1, 1e2) + + +# - + +# ### The Amplitude Spectrum in NIFTy.re +# +# The correlation operator $A$ which used to transform a to-be-inferred signal $s$ to standardized $\xi$ coordinates, is given by +# $ s = A \xi $ +# +# and A is defined as (see below for derivation) +# +# $$ A \mathrel{\mathop{:}}= F^{-1}\hat{S}^{1/2} $$ +# +# Where $F^{-1}$ is the inverse Fourier transform, and $\hat{S}$ is the diagonalized correlation structure in harmonic space. +# +# The norm is defined as: +# +# \begin{equation*} +# \text{norm} = \sqrt{\frac{\int \textnormal{d}k p(k)}{V}} +# \end{equation*} +# +# The amplitude spectrum $\text{amp(k)}$ is then +# \begin{align*} +# \text{amp(k)} = \frac{\text{fluc} \cdot \sqrt{V}}{\text{norm}} \cdot \sqrt{p(k)} \\ +# \text{amp(k)} = \frac{\text{fluc} \cdot V \cdot \sqrt{p(k)}}{\sqrt{\int \textnormal{d}k p(k)}} +# \end{align*} +# +# The power spectrum is just the amplitude spectrum squared: +# +# \begin{align*} +# p(k) &= \text{amp}^2(k) \\ +# & = \frac{\text{fluc}^2 \cdot V^2 \cdot p(k)}{\int \textnormal{d}k p(k)} \\ +# \int \textnormal{d}k p(k) & = \text{fluc}^2 \cdot V^2 +# \end{align*} +# +# Hence, the fluctuations in `NIFTy.re` are given by +# +# \begin{equation*} +# \text{fluc} = \sqrt{\frac{\int \textnormal{d}k p(k)}{V^2}} +# \end{equation*} +# +# This is different as compared to the Numpy-based NIFTy, where: +# +# \begin{equation*} +# \text{fluc} = \int \textnormal{d}k p(k) +# \end{equation*} +# +# + +# ## The Neutral Field + +# + +cf_args = { + "fluctuations": (1e-3, 1e-16), + "loglogavgslope": (0.0, 1e-16), + "flexibility": (1e-3, 1e-16), + "asperity": (1e-3, 1e-16), + "prefix": "", + "offset_mean": 0.0, + "offset_std": (1e-3, 1e-16), +} + +cf_model, pow_cf = fieldmaker(npix, distances, matern=False, **cf_args) + +key, signalk = jax.random.split(key, num=2) +syn_signal = jft.random_like(signalk, cf_model.domain) +syn_data = cf_model(syn_signal) +syn_pow = pow_cf(syn_signal) +fig = plt.figure(tight_layout=True, figsize=(10, 3)) +ax1 = fig.add_subplot(1, 2, 1) +ax1.set_title("Field Realizations") +ax1.set_ylim( + -4.0, + 4, +) +ax2 = fig.add_subplot(1, 2, 2) +ax2.set_xscale("log") +ax2.set_yscale("log") +ax2.set_title("Power Spectra") +ax2.set_ylim(1e-6, 2.0) +ax1.plot(k_lengths, syn_data, linewidth=1) +ax2.plot(np.arange(len(k_lengths) / 2 + 1), np.sqrt(syn_pow / totvol ** 2), linewidth=1) + +# - + +# ## The `fluctuations` parameters of `add_fluctuations()` +# +# `fluctuations` determine the **amplitude of variations** along the field dimension for which add_fluctuations is called. +# +# `fluctuations[0]` set the average amplitude of the fields fluctuations along the given dimension, +# `fluctuations[1]` sets the width and shape of the amplitude distribution. +# + +# ## `fluctuations` mean + +vary_parameter( + "fluctuations", [(0.05, 1e-16), (0.5, 1e-16), (1.0, 1e-16)], matern=False, **cf_args +) +cf_args["fluctuations"] = (1.0, 1e-16) + +# ## `fluctuations` std + +vary_parameter( + "fluctuations", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args +) +cf_args["fluctuations"] = (1.0, 1e-16) + + +# ## The `loglogavgslope` parameters of `add_fluctuations()` +# +# The value of `loglogavgslope` determines the __slope of the loglog-linear (power law) component of the power spectrum.__ +# +# The slope is modelled to be normally distributed. + +# ## `loglogavgslope` mean + +vary_parameter( + "loglogavgslope", + [(-6.0, 1e-16), (-2.0, 1e-16), (2.0, 1e-16)], + matern=False, + **cf_args, +) + +# ## `loglogavgslope` std + +vary_parameter( + "loglogavgslope", [(-2.0, 0.02), (-2.0, 0.2), (-2.0, 2.0)], matern=False, **cf_args +) +cf_args["loglogavgslope"] = (-2.0, 1e-16) + +# ## The `flexibility` parameters of `add_fluctuations()` +# +# Values for `flexibility` determine the __amplitude of the integrated Wiener process component of the power spectrum__ (how strong the power spectrum varies besides the power-law). +# +# `flexibility[0]` sets the _average_ amplitude of the i.g.p. component, +# `flexibility[1]` sets how much the amplitude can vary. +# These two parameters feed into a moment-matched log-normal distribution model, see above for a demo of its behavior. + +# ## `flexibility` mean + +vary_parameter( + "flexibility", [(0.4, 1e-16), (4.0, 1e-16), (12.0, 1e-16)], matern=False, **cf_args +) + +# ## `flexibility` std + +vary_parameter( + "flexibility", [(4.0, 0.02), (4.0, 0.2), (4.0, 2.0)], matern=False, **cf_args +) +cf_args["flexibility"] = (4.0, 1e-16) + +# ## The `asperity` parameters of `add_fluctuations()` +# +# `asperity` determines how __rough the integrated Wiener process component of the power spectrum is.__ +# +# `asperity[0]` sets the average roughness, `asperity[1]` sets how much the roughness can vary. +# These two parameters feed into a moment-matched log-normal distribution model, see above for a demo of its behavior. +# +# + +# ## `asperity` mean + +vary_parameter( + "asperity", [(0.001, 1e-16), (1.0, 1e-16), (5.0, 1e-16)], matern=False, **cf_args +) + +# ## `asperity` std + +vary_parameter( + "asperity", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args +) +cf_args["asperity"] = (1.0, 1e-16) + +# ## The `offset_mean` parameter of `CorrelatedFieldMaker()` +# +# The `offset_mean` parameter defines a global additive offset on the field realizations. +# +# If the field is used for a lognormal model `f = field.exp()`, this acts as a global signal magnitude offset. +# +# + +# Reset model to neutral +cf_args["fluctuations"] = (1e-3, 1e-16) +cf_args["flexibility"] = (1e-3, 1e-16) +cf_args["asperity"] = (1e-3, 1e-16) +cf_args["loglogavgslope"] = (1e-3, 1e-16) + +vary_parameter("offset_mean", [3.0, 0.0, -2.0], matern=False, **cf_args) + +vary_parameter( + "offset_std", [(1e-16, 1e-16), (0.5, 1e-16), (2.0, 1e-16)], matern=False, **cf_args +) + +vary_parameter( + "offset_std", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args +) + +# ## Matern Fluctuation Kernels +# +# The correlated fields model also supports parametrizing the power spectra of field dimensions using Matern kernels. In the following, the effects of their parameters are demonstrated. +# +# Contrary to the field fluctuations parametrization showed above, the Matern kernel parameters show strong interactions. For example, the field amplitude does not only depend on the amplitude scaling parameter `scale`, but on the combination of all three parameters `scale`, `cutoff` and `loglogavgslope`. + +# + +# Neutral model parameters yielding a quasi-constant field + +cf_args_matern = { + "scale": (1e-2, 1e-16), + "cutoff": (1.0, 1e-16), + "loglogslope": (-2.0, 1e-16), + "prefix": "", + "offset_mean": 0.0, + "offset_std": (1e-3, 1e-16), +} +# - + +vary_parameter( + "scale", [(0.01, 1e-16), (0.1, 1e-16), (1.0, 1e-16)], matern=True, **cf_args_matern +) + +vary_parameter( + "scale", [(0.5, 0.01), (0.5, 0.1), (0.5, 0.5)], matern=True, **cf_args_matern +) +cf_args_matern["scale"] = (0.5, 1e-16) + +vary_parameter( + "cutoff", [(10.0, 1.0), (10.0, 3.16), (10.0, 100.0)], matern=True, **cf_args_matern +) + +# ## Theory of the Correlated Field Model +# +# - Want to infer a signal $s$, and have a (multivariate Gaussian) prior $\mathcal{G}(s,S)$ +# - Inference algorithms are sensitive to coordinates, more efficient to infer $\mathcal{G}(\xi, 1)$ as opposed to $\mathcal{G}(s,S)$ +# - Can find a coordinate transform (known as the amplitude transform) which changes the distribution to standardized coordinates $\xi \hookleftarrow \mathcal{G}(\xi, 1)$. +# - This transform is given by $ s = F^{-1}\hat{S}^{1/2}\xi \mathrel{\mathop{:}}= A\xi $ +# +# We can now define a Correlation Operator $A$ as +# \begin{equation} +# A \mathrel{\mathop{:}}= F^{-1}\hat{S}^{1/2} +# \end{equation} +# which relates $\xi$ from the latent space to $s$ in physical space. +# +# The correlation structure $S$ and $A$ are related as, +# \begin{equation} +# S = A A^{\dagger} +# \end{equation} +# +# +# (Here we assume that the correlation structure $S$ is statistically homogenous and stationary and isotropic (which means that $S$ does not have a preferred location or direction a priori), which then according to the Wiener-Khintchin theorem can be represented as a diagonal matrix in Harmonic (Fourier) Space.) +# +# Then, +# +# $S$ can be written as $S = F^\dagger (\hat{p_S}) F $ where $\hat{p}$ is a diagonal matrix with the power spectra values for each mode on the diagonal. +# +# From the relation given above, the amplitude spectrum $p_A$ is related to the power spectrum as: +# $$ p_A(k) = \sqrt{p_S(k)} $$ +# $$ C_s(k) = \lim\limits_{V \to \infty} \frac{1}{V} \left< \left| \int_{V} dx s^{x} e^{ikx} \right| \right>_{s}$$ +# +# +# If we do not have enough information about the correlation structure $S$ (and consequently the amplitude $A$), we build a model for it and learn it using data. Again, assuming statistical homogeneity a priori, then the Wiener Khintchin theorem yields: +# +# $$ A^{kk'} = (2\pi)^{d} \delta(k - k') p_A(k) $$ +# +# We can now non-parametrically model $p_A(k)$, with the only constraint being the positivity of the power spectrum. To enforce positivity, let us model $\gamma(k)$ where +# $$ p_A(k) \propto e^{\gamma(k)} $$ +# +# $\gamma(k)$ is modeled with an Integrated Wiener Process in $l = log|k|$ coordinates. In logarithmic coordinates, the zero mode $|k| = 0$ is infinitely far away from the other modes, hence it is treated separately. +# +# The integrated Wiener process is a stochastic process that is described by the equation: +# +# $$ \frac{d^2 \gamma(k)}{dl^2} = \eta (l) $$ +# where $\eta(l)$ is standard Gaussian distributed. +# +# One way to solve the differential equation is by splitting the single second order differential equation into two first order differential equations by substitution. +# +# $$ v = \frac{d \gamma(k)}{dl} $$ +# $$ \frac{dv}{dl} = \eta(l) $$ +# +# Integrating both equations wrt $l$ results in: +# +# \begin{equation} +# \int_{l_0}^{l} v(l') dl' = \gamma(k) - \gamma(0) +# \end{equation} +# +# \begin{equation} +# v(l') - v(l_0) = \int_{l_0}^{l'} \eta(l'') dl'' +# \end{equation} +# +# +# +# +# +# +# + +# ## Mathematical intuition for the Fluctuations parameter +# +# The two-point correlation function between two locations $x$ and $y$ is given by $S^{xy}$ which is a continuous function of the distance between the two points, $x-y$ assuming a priori statistical homogeneity. $$ S^{xy}= C_s(x-y) \mathrel{\mathop{:}}= C_s(r)$$ +# +# When this is the case, the correlation structure is diagonalized in harmonic space and is described fully by the Power Spectrum $P_s(k)$ i.e, +# +# \begin{align*} +# S^{xy} & = (F^\dagger)^r_k P_s (k) \\ +# & = \int \frac{\mathop{dk}}{(2\pi)^u} \exp{[-ikr]} P_s(k) \\ +# \end{align*} +# +# Then, the auto correlation +# +# \begin{align*} +# S^{xx} & = \left< s^x s^{x*}\right> \\ +# & = C_s(0) = \int \frac{dk}{(2\pi)^u} e^{0} P_s(k) \\ +# \int P_s(k) \mathop{dk} & = \left< |s^x|^2 \right> +# +# \end{align*} +# +# This $P_s(k)$ is modified to a power spectrum with added fluctuations, +# +# \begin{align*} +# P_s ' = \frac{P_s}{\sqrt{\int \mathop{dk} P_s(k)}} a^2 +# \end{align*} diff --git a/docs/source/user/old_nifty.rst b/docs/source/user/old_nifty.rst index 0e3d69dcacd30164816e81b20103877d4f606cdb..ea1e82f2f7a96f35861c309b6fea9316d68c8954 100644 --- a/docs/source/user/old_nifty.rst +++ b/docs/source/user/old_nifty.rst @@ -1,5 +1,17 @@ +NumPy-Based NIFTy +================== + +.. toctree:: + :maxdepth: 1 + + old_nifty_volume + old_nifty_design_principles + old_nifty_custom_nonlinearities + old_nifty_getting_started_0 + old_nifty_getting_started_4_CorrelatedFields + What is NIFTy? -============== +-------------- **NIFTy** [1]_ [2]_ [3]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying grids (spatial, spectral, temporal, …) and their resolutions. Its object-oriented framework is written in Python, although it accesses libraries written in C++ and C for efficiency. diff --git a/paper/paper.bib b/docs/source/user/paper.bib similarity index 99% rename from paper/paper.bib rename to docs/source/user/paper.bib index c54764190e33dd0eaad9b3cc52784e36133771df..1d0148fc8eaf2f3fdeb6ba22875790b21fd6f46a 100644 --- a/paper/paper.bib +++ b/docs/source/user/paper.bib @@ -15,7 +15,7 @@ adsurl = {https://ui.adsabs.harvard.edu/abs/2019A&A...627A.134A}, adsnote = {Provided by the SAO/NASA Astrophysics Data System} } -@software{Arras2019NIFTy, +@misc{Arras2019NIFTy, author = {{Arras}, Philipp and {Baltac}, Mihai and {Ensslin}, Torsten A. and {Frank}, Philipp and {Hutschenreuter}, Sebastian and {Knollmueller}, Jakob and {Leike}, Reimar and {Newrzella}, Max-Niklas and {Platz}, Lukas and {Reinecke}, Martin and {Stadler}, Julia}, title = {{NIFTy5: Numerical Information Field Theory v5}}, howpublished = {Astrophysics Source Code Library, record ascl:1903.008}, @@ -96,6 +96,17 @@ adsurl = {https://ui.adsabs.harvard.edu/abs/2022ApJ...935..167A}, adsnote = {Provided by the SAO/NASA Astrophysics Data System} } +@article{Blondel2021, + author = {Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-Lopez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe}, + journal = {Advances in Neural Information Processing Systems}, + editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, + pages = {5230--5242}, + publisher = {Curran Associates, Inc.}, + title = {Efficient and Modular Implicit Differentiation}, + url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/228b9279ecf9bbafe582406850c57115-Paper-Conference.pdf}, + volume = {35}, + year = {2022} +} @article{Bingham2019, author = {Eli Bingham and Jonathan P. Chen and Martin Jankowiak and Fritz Obermeyer and Neeraj Pradhan and Theofanis Karaletsos and Rohit Singh and Paul A. Szerlip and Paul Horsfall and Noah D. Goodman}, title = {Pyro: Deep Universal Probabilistic Programming}, @@ -105,24 +116,13 @@ year = {2019}, url = {http://jmlr.org/papers/v20/18-403.html} } -@software{blackjax2020, - author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi}, +@misc{blackjax2020, + author = {Cabezas and Alberto and Lao and Junpeng and Louf and R\'emi}, title = {{B}lackjax: A sampling library for {JAX}}, url = {http://github.com/blackjax-devs/blackjax}, version = {v1.1.0}, year = {2023} } -@article{Blondel2021, - author = {Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-Lopez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe}, - booktitle = {Advances in Neural Information Processing Systems}, - editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, - pages = {5230--5242}, - publisher = {Curran Associates, Inc.}, - title = {Efficient and Modular Implicit Differentiation}, - url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/228b9279ecf9bbafe582406850c57115-Paper-Conference.pdf}, - volume = {35}, - year = {2022} -} @article{Carpenter2017, title = {Stan: A Probabilistic Programming Language}, volume = {76}, @@ -134,13 +134,13 @@ year = {2017}, pages = {1-32} } -@software{Deepmind2020Optax, +@misc{Deepmind2020Optax, title = {The {D}eep{M}ind {JAX} {E}cosystem}, author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio}, url = {http://github.com/google-deepmind}, year = {2020} } -@software{ducc0, +@misc{ducc0, author = {Martin Reinecke}, title = {{DUCC}: Distinctly Useful Code Collection}, url = {https://gitlab.mpcdf.mpg.de/mtr/ducc}, @@ -237,7 +237,7 @@ url = {https://arxiv.org/abs/1703.09710} } -@software{ForemanMackey2024, +@misc{ForemanMackey2024, author = {Foreman-Mackey, Daniel and Weixiang Yu and Yadav, Sachin and Becker, McCoy Reynolds and Caplar, Neven and Huppenkothen, Daniela and Killestein, Thomas and Tronsgaard, René and Rashid, Theo and Schmerler, Steve}, title = {{dfm/tinygp: The tiniest of Gaussian Process libraries}}, month = jan, @@ -429,7 +429,7 @@ adsurl = {https://ui.adsabs.harvard.edu/abs/2023arXiv230412350H}, adsnote = {Provided by the SAO/NASA Astrophysics Data System} } -@software{Jax2018, +@misc{Jax2018, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, url = {http://github.com/google/jax}, @@ -452,7 +452,7 @@ year = {2019}, copyright = {arXiv.org perpetual, non-exclusive license} } -@software{Koposov2023, +@misc{Koposov2023, author = {Sergey Koposov and Josh Speagle and Kyle Barbary and Gregory Ashton and Ed Bennett and Johannes Buchner and Carl Scheffler and Ben Cook and Colm Talbot and James Guillochon and Patricio Cubillos and Andrés Asensio Ramos and Ben Johnson and Dustin Lang and Ilya and Matthieu Dartiailh and Alex Nitz and Andrew McCluskey and Anne Archibald}, title = {joshspeagle/dynesty: v2.1.3}, month = oct, @@ -768,3 +768,4 @@ journal = {Journal of Open Source Software} } + diff --git a/docs/source/user/paper.md b/docs/source/user/paper.md deleted file mode 120000 index 9785e6c38b9caa2baadb46a14319f883165e417f..0000000000000000000000000000000000000000 --- a/docs/source/user/paper.md +++ /dev/null @@ -1 +0,0 @@ -../../../paper/paper.md \ No newline at end of file diff --git a/docs/source/user/paper.rst b/docs/source/user/paper.rst new file mode 100644 index 0000000000000000000000000000000000000000..f3489f92e9a9dce1184af87ca57abf504ebccb7a --- /dev/null +++ b/docs/source/user/paper.rst @@ -0,0 +1,420 @@ +.. raw:: html + + <!-- + ## JAX + NIFTy Paper + * USP: selling point: speed + * Bonus: higher order diff for more efficient optimization and all of Tensorflow and Tensorflow for all + * GP + * Regular Grid Refinement + * KISS-GP + * Grid Refinement + * Posterior Approx. + * HMC but with variable dtype handling + * JIT-able VI and also (indirectly) available for Tensorflow + * predecessor enabled 100B reconstruction + * middle ground between tools like blackjax and pymc + --> + +Publication +=========== + +Summary +------- + +Imaging is the process of transforming noisy, incomplete data into a +space that humans can interpret. ``NIFTy`` is a Bayesian framework for +imaging and has already successfully been applied to many fields in +astrophysics. Previous design decisions held the performance and the +development of methods in ``NIFTy`` back. We present a rewrite of +``NIFTy``, coined ``NIFTy.re``, which reworks the modeling principle, +extends the inference strategies, and outsources much of the heavy +lifting to JAX. The rewrite dramatically accelerates models written in +``NIFTy``, lays the foundation for new types of inference machineries, +improves maintainability, and enables interoperability between ``NIFTy`` +and the JAX machine learning ecosystem. + +Statement of Need +----------------- + +Imaging commonly involves millions to billions of pixels. Each pixel +usually corresponds to one or more correlated degrees of freedom in the +model space. Modeling this many degrees of freedom is computationally +demanding. However, imaging is not only computationally demanding but +also statistically challenging. The noise in the data requires a +statistical treatment and needs to be accurately propagated from the +data to the uncertainties in the final image. To do this, we require an +inference machinery that not only handles extremely high-dimensional +spaces, but one that does so in a statistically rigorous way. + +NIFTy is a Bayesian imaging library :cite:`Selig2013`, :cite:`Steiniger2017`, :cite:`Arras2019NIFTy` designed to infer million- to billion-dimensional posterior distributions from noisy data. At its core are Gaussian Process (GP) models and Variational Inference (VI) algorithms. + + +``NIFTy.re`` is a rewrite of ``NIFTy`` in JAX :cite:`Jax2018` with all +relevant previous GP models, new, more flexible GP models, and a more +flexible machinery for approximating posterior distributions. Being +written in JAX, ``NIFTy.re`` effortlessly runs on accelerator hardware +such as the GPU and TPU, vectorizes models whenever possible, and +just-in-time compiles code for additional performance. ``NIFTy.re`` +switches from a home-grown automatic differentiation engine that was +used in ``NIFTy`` to JAX’s automatic differentiation engine. This lays +the foundation for new types of inference machineries that make use of +the higher order derivatives provided by JAX. Through these changes, we +envision to harness significant gains in maintainability of ``NIFTy.re`` +compared to ``NIFTy`` and a faster development cycle for new features. + +.. raw:: html + + <!-- Mention (if applicable) a representative set of past or ongoing research projects using the software and recent scholarly publications enabled by it. --> + +We expect ``NIFTy.re`` to be highly useful for many imaging applications +and envision many applications within and outside of astrophysics +:cite:`Arras2019, Arras2022, Leike2019, Leike2020, Mertsch2023, +Roth2023DirectionDependentCalibration, Hutschenreuter2023, +Tsouros2023, Roth2023FastCadenceHighContrastImaging, +Hutschenreuter2022, ScheelPlatz2023, Frank2017, Welling2021, +Westerkamp2023, Eberle2023ButterflyImaging, +Eberle2023ButterflyImagingAlgorithm`. ``NIFTy.re`` has already been +successfully used in two galactic tomography publications :cite:`Leike2022, +Edenhofer2023`. + + +.. raw:: html + + <!-- A list of key references, including to other software addressing related needs. Note that the references should include full names of venues, e.g., journals and conferences, not abbreviations only understood in the context of a specific discipline. --> + +NIFTy.re competes with other GP libraries as well as with probabilistic programming languages and frameworks. +Compared to GPyTorch :cite:`Hensman2015`, GPflow :cite:`Matthews2017`, george :cite:`Sivaram2015`, or TinyGP :cite:`ForemanMackey2024`, +NIFTy.re focuses on GP models for structured spaces and does not assume the posterior to be analytically accessible. +Instead, NIFTy.re tries to approximate the true posterior using VI. Compared to classical +probabilistic programming languages such as Stan :cite:`Carpenter2017` and frameworks such as +Pyro :cite:`Bingham2019`, NumPyro :cite:`Phan2019`, pyMC3 :cite:`Salvatier2016`, +emcee :cite:`ForemanMackey2013`, dynesty :cite:`Speagle2020, Koposov2023`, or BlackJAX :cite:`blackjax2020`, +NIFTy.re focuses on inference in extremely high-dimensional spaces. NIFTy.re +exploits the structure of probabilistic models in its VI techniques :cite:`Frank2021`. +With NIFTy.re, the GP models and the VI machinery are now fully accessible in the JAX ecosystem and +NIFTy.re components interact seamlessly with other JAX packages such as BlackJAX and JAXopt/Optax :cite:`Blondel2021, Deepmind2020Optax`. + +Core Components +--------------- + +``NIFTy.re`` brings tried and tested structured GP models and VI +algorithms to JAX. GP models are highly useful for imaging problems, and +VI algorithms are essential to probe high-dimensional posteriors, which +are often encountered in imaging problems. ``NIFTy.re`` infers the +parameters of interest from noisy data via a stochastic mapping that +goes in the opposite direction, from the parameters of interest to the +data. + +``NIFTy`` and ``NIFTy.re`` build up hierarchical models for the +posterior inference. The log-posterior function reads +:math:`\ln{p(\theta|d)} := \ell(d, f(\theta)) + \ln{p}(\theta) + \mathrm{const}` +with log-likelihood :math:`\ell`, forward model :math:`f` mapping the +parameters of interest :math:`\theta` to the data space, and log-prior +:math:`\ln{p(\theta)}`. The goal of the inference is to draw samples +from the posterior :math:`p(\theta|d)`. + +What is considered part of the likelihood versus part of the prior is +ill-defined. Without loss of generality, ``NIFTy`` and ``NIFTy.re`` +re-formulate models such that the prior is always standard Gaussian. +They implicitly define a mapping from a new latent space with a priori +standard Gaussian parameters :math:`\xi` to the parameters of interest +:math:`\theta`. The mapping :math:`\theta(\xi)` is incorporated into the +forward model :math:`f(\theta(\xi))` in such a way that all relevant +details of the prior model are encoded in the forward model. This choice +of re-parameterization :cite:`Rezende2015` is called standardization. It is +often carried out implicitly in the background without user input. + +Gaussian Processes +------------------ + +One standard tool from the ``NIFTy.re`` toolbox is the so-called +correlated field GP model from ``NIFTy``. This model relies on the +harmonic domain being easily accessible. For example, for pixels spaced +on a regular Cartesian grid, the natural choice to represent a +stationary kernel is the Fourier domain. In the generative picture, a +realization :math:`s` drawn from a GP then reads +:math:`s = \mathrm{FT} \circ \sqrt{P} \circ \xi` with +:math:`\mathrm{FT}` the (fast) Fourier transform, :math:`\sqrt{P}` the +square-root of the power-spectrum in harmonic space, and :math:`\xi` +standard Gaussian random variables. In the implementation in +``NIFTy.re`` and ``NIFTy``, the user can choose between two adaptive +kernel models, a non-parametric kernel :math:`\sqrt{P}` and a Matérn +kernel :math:`\sqrt{P}` :cite:`Arras2022, Guardiani2022` for details on their +implementation]. A code example that initializes a non-parametric GP +prior for a :math:`128 \times 128` space with unit volume is shown in +the following. + +.. code:: python + + from nifty8 import re as jft + + dims = (128, 128) + cfm = jft.CorrelatedFieldMaker("cf") + cfm.set_amplitude_total_offset(offset_mean=2, offset_std=(1e-1, 3e-2)) + # Parameters for the kernel and the regular 2D Cartesian grid for which + # it is defined + cfm.add_fluctuations( + dims, + distances=tuple(1.0 / d for d in dims), + fluctuations=(1.0, 5e-1), + loglogavgslope=(-3.0, 2e-1), + flexibility=(1e0, 2e-1), + asperity=(5e-1, 5e-2), + prefix="ax1", + non_parametric_kind="power", + ) + # Get the forward model for the GP prior + correlated_field = cfm.finalize() + +Not all problems are well described by regularly spaced pixels. For more +complicated pixel spacings, ``NIFTy.re`` features Iterative Charted +Refinement :cite:`Edenhofer2022`, a GP model for arbitrarily deformed spaces. +This model exploits nearest neighbor relations on various coarsenings of +the discretized modeled space and runs very efficiently on GPUs. For +one-dimensional problems with arbitrarily spaced pixels, ``NIFTy.re`` +also implements multiple flavors of Gauss-Markov processes. + +Building Up Complex Models +-------------------------- + +Models are rarely just a GP prior. Commonly, a model contains at least a +few non-linearities that transform the GP prior or combine it with other +random variables. For building more complex models, ``NIFTy.re`` +provides a ``Model`` class that offers a somewhat familiar +object-oriented design yet is fully JAX compatible and functional under +the hood. The following code shows how to build a slightly more complex +model using the objects from the previous example. + +.. code:: python + + from jax import numpy as jnp + + + class Forward(jft.Model): + def __init__(self, correlated_field): + self._cf = correlated_field + # Tracks a callable with which the model can be initialized. This + # is not strictly required, but comes in handy when building deep + # models. Note, the init method (short for "initialization" method) + # is not to be confused with the prior, which is always standard + # Gaussian. + super().__init__(init=correlated_field.init) + + def __call__(self, x): + # NOTE, any kind of masking of the output, non-linear and linear + # transformation could be carried out here. Models can also be + # combined and nested in any way and form. + return jnp.exp(self._cf(x)) + + + forward = Forward(correlated_field) + + data = jnp.load("data.npy") + lh = jft.Poissonian(data).amend(forward) + +All GP models in ``NIFTy.re`` as well as all likelihoods behave like +instances of ``jft.Model``, meaning that JAX understands what it means +if a computation involves ``self``, other ``jft.Model`` instances, or +their attributes. In other words, ``correlated_field``, ``forward``, and +``lh`` from the code snippets shown here are all so-called pytrees in +JAX, and, for example, the following is valid code +``jax.jit(lambda l, x: l(x))(lh, x0)`` with ``x0`` some arbitrarily +chosen valid input to ``lh``. Inspired by equinox :cite:`Kidger2021`, +individual attributes of the class can be marked as non-static or static +via ``dataclass.field(metadata=dict(static=...))`` for the purpose of +compiling. Depending on the value, JAX will either treat the attribute +as an unknown placeholder or as a known concrete attribute and +potentially inline it during compilation. This mechanism is extensively +used in likelihoods to avoid inlining large constants such as the data +and to avoid expensive re-compilations whenever possible. + +Variational Inference +--------------------- + +``NIFTy.re`` is built for models with millions to billions of degrees of +freedom. To probe the posterior efficiently and accurately, ``NIFTy.re`` +relies on VI. Specifically, ``NIFTy.re`` implements Metric Gaussian +Variational Inference (MGVI) and its successor geometric Variational +Inference (geoVI) :cite:`Knollmueller2019, Frank2021, Frank2022`. At the +core of both MGVI and geoVI lies an alternating procedure in which one +switches between optimizing the Kullback–Leibler divergence for a +specific shape of the variational posterior and updating the shape of +the variational posterior. MGVI and geoVI define the variational +posterior via samples, specifically, via samples drawn around an +expansion point. The samples in MGVI and geoVI exploit model-intrinsic +knowledge of the posterior’s approximate shape, encoded in the Fisher +information metric and the prior curvature :cite:`Frank2021`. + +``NIFTy.re`` allows for much finer control over the way samples are +drawn and updated compared to ``NIFTy``. ``NIFTy.re`` exposes +stand-alone functions for drawing MGVI and geoVI samples from any +arbitrary model with a likelihood from ``NIFTy.re`` and a forward model +that is differentiable by JAX. In addition to stand-alone sampling +functions, ``NIFTy.re`` provides tools to configure and execute the +alternating Kullback–Leibler divergence optimization and sample adaption +at a lower abstraction level. These tools are provided in a +JAXopt/Optax-style optimizer class :cite:`Blondel2021, Deepmind2020Optax`. +A typical minimization with ``NIFTy.re`` is shown in the following. It +retrieves six independent, antithetically mirrored samples from the +approximate posterior via 25 iterations of alternating between +optimization and sample adaption. The final result is stored in the +``samples`` variable. A convenient one-shot wrapper for the code below +is ``jft.optimize_kl``. By virtue of all modeling tools in ``NIFTy.re`` +being written in JAX, it is also possible to combine ``NIFTy.re`` tools +with BlackJAX :cite:`blackjax2020` or any other posterior sampler in the JAX +ecosystem. + +.. code:: python + + from jax import random + + key = random.PRNGKey(42) + key, sk = random.split(key, 2) + # NIFTy is agnostic w.r.t. the type of inputs it gets as long as they + # support core arithmetic properties. Tell NIFTy to treat our parameter + # dictionary as a vector. + samples = jft.Samples(pos=jft.Vector(lh.init(sk)), samples=None) + + delta = 1e-4 + absdelta = delta * jft.size(samples.pos) + + opt_vi = jft.OptimizeVI(lh, n_total_iterations=25) + opt_vi_st = opt_vi.init_state( + key, + # Implicit definition for the accuracy of the KL-divergence + # approximation; typically on the order of 2-12 + n_samples=lambda i: 1 if i < 2 else (2 if i < 4 else 6), + # Parametrize the conjugate gradient method at the heart of the + # sample-drawing + draw_linear_kwargs=dict( + cg_name="SL", cg_kwargs=dict(absdelta=absdelta / 10.0, maxiter=100) + ), + # Parametrize the minimizer in the nonlinear update of the samples + nonlinearly_update_kwargs=dict( + minimize_kwargs=dict( + name="SN", xtol=delta, cg_kwargs=dict(name=None), maxiter=5 + ) + ), + # Parametrize the minimization of the KL-divergence cost potential + kl_kwargs=dict(minimize_kwargs=dict(name="M", xtol=delta, maxiter=35)), + sample_mode="nonlinear_resample", + ) + for i in range(opt_vi.n_total_iterations): + print(f"Iteration {i+1:04d}") + # Continuously update the samples of the approximate posterior + # distribution + samples, opt_vi_st = opt_vi.update(samples, opt_vi_st) + print(opt_vi.get_status_message(samples, opt_vi_st)) + +.. _fig_minimal_reconstruction_data_mean_std: +.. figure:: minimal_reconstruction_data_mean_std.png + :alt: Data (left), posterior mean (middle), and posterior uncertainty (right) for a simple toy example. + :align: center + + Data (left), posterior mean (middle), and posterior uncertainty (right) for a simple toy example. + +:numref:`fig_minimal_reconstruction_data_mean_std` shows an exemplary posterior reconstruction employing the above model. +The posterior mean agrees with the data but removes noisy structures. The posterior standard deviation is approximately equal to typical differences between the posterior mean and the data. + + +Performance of ``NIFTy.re`` compared to ``NIFTy`` +------------------------------------------------- + +We test the performance of ``NIFTy.re`` against ``NIFTy`` for the simple +yet representative model from above. To assess the performance, we +compare the time required to apply :math:`M_p := F_p + \mathbb{1}` to +random input with :math:`F_p` denoting the Fisher metric of the overall +likelihood at position :math:`p` and :math:`\mathbb{1}` the identity +matrix. Within ``NIFTy.re``, the Fisher metric of the overall likelihood +is decomposed into :math:`J_{f,p}^\dagger N^{-1} J_{f,p}` with +:math:`J_{f,p}` the implicit Jacobian of the forward model :math:`f` at +:math:`p` and :math:`N^{-1}` the Fisher-metric of the Poisson +likelihood. We choose to benchmark :math:`M_p` as a typical VI +minimization in ``NIFTy.re`` and ``NIFTy`` is dominated by calls to this +function. + +.. _fig_benchmark_nthreads=1+8_devices=cpu+gpu: +.. figure:: benchmark_nthreads=1+8_devices=cpu+gpu.png + :alt: Median evaluation time of applying the Fisher metric plus the identity metric to random input for ``NIFTy.re`` and ``NIFTy`` on the CPU (one and eight core(s) of an Intel Xeon Platinum 8358 CPU clocked at 2.60G Hz) and the GPU (A100 SXM4 80 GB HBM2). The quantile range from the 16%- to the 84%-quantile is obscured by the marker symbols. + + Median evaluation time of applying the Fisher metric plus the identity metric to random input for ``NIFTy.re`` and ``NIFTy`` on the CPU (one and eight core(s) of an Intel Xeon Platinum 8358 CPU clocked at 2.60G Hz) and the GPU (A100 SXM4 80 GB HBM2). The quantile range from the 16%- to the 84%-quantile is obscured by the marker symbols. + +:numref:`fig_benchmark_nthreads=1+8_devices=cpu+gpu` shows +the median evaluation time in ``NIFTy`` of applying :math:`M_p` to new, +random tangent positions and the evaluation time in ``NIFTy.re`` of +building :math:`M_p` and applying it to new, random tangent positions +for exponentially larger models. The 16%-quantiles and the 84%-quantiles +of the timings are obscured by the marker symbols. We chose to exclude +the build time of :math:`M_p` in ``NIFTy`` from the comparison, putting +``NIFTy`` at an advantage, as its automatic differentiation is built +around calls to :math:`M_p` with :math:`p` rarely varying. We ran the +benchmark on one CPU core, eight CPU cores, and on a GPU on a +compute-node with an Intel Xeon Platinum 8358 CPU clocked at 2.60G Hz +and an NVIDIA A100 SXM4 80 GB HBM2 GPU. The benchmark used +``jax==0.4.23`` and ``jaxlib==0.4.23+cuda12.cudnn89``. We vary the size +of the model by increasing the size of the two-dimensional square image +grid. + + +For small image sizes, ``NIFTy.re`` on the CPU is about one order of +magnitude faster than ``NIFTy``. Both reach about the same performance +at an image size of roughly 15,000 pixels and continue to perform +roughly the same for larger image sizes. The performance increases by a +factor of three to four with eight cores for ``NIFTy.re`` and ``NIFTy``, +although ``NIFTy.re`` is slightly better at using the additional cores. +On the GPU, ``NIFTy.re`` is consistently about one to two orders of +magnitude faster than ``NIFTy`` for images larger than 100,000 pixels. + +We believe the performance benefits of ``NIFTy.re`` on the CPU for small +models stem from the reduced Python overhead by just-in-time compiling +computations. At image sizes larger than roughly 15,000 pixels, both +evaluation times are dominated by the fast Fourier transform and are +hence roughly the same as both use the same underlying implementation +:cite:`ducc0`. Models in ``NIFTy.re`` and ``NIFTy`` are often well aligned +with GPU programming models and thus consistently perform well on the +GPU. Modeling components such as the new GP models implemented in +``NIFTy.re`` are even better aligned with GPU programming paradigms and +yield even higher performance gains :cite:`Edenhofer2022`. + +Conclusion +---------- + +``NIFTy.re`` implements the core GP and VI machinery of the Bayesian +imaging package ``NIFTy`` in JAX. The rewrite moves much of the +heavy-lifting from home-grown solutions to JAX, and we envision +significant gains in maintainability of ``NIFTy.re`` and a faster +development cycle moving forward. The rewrite accelerates typical models +written in ``NIFTy`` by one to two orders of magnitude, lays the +foundation for new types of inference machineries by enabling higher +order derivatives via JAX, and enables the interoperability of +``NIFTy``\ ’s VI and GP methods with the JAX machine learning ecosystem. + + +Acknowledgements +---------------- + +Gordian Edenhofer acknowledges support from the German Academic +Scholarship Foundation in the form of a PhD scholarship +(“Promotionsstipendium der Studienstiftung des Deutschen Volkesâ€). +Philipp Frank acknowledges funding through the German Federal Ministry +of Education and Research for the project “ErUM-IFT: +Informationsfeldtheorie für Experimente an Großforschungsanlagen†+(Förderkennzeichen: 05D23EO1). Jakob Roth acknowledges financial support +by the German Federal Ministry of Education and Research (BMBF) under +grant 05A20W01 (Verbundprojekt D-MeerKAT). Matteo Guardiani, Vincent +Eberle, and Margret Westerkamp acknowledge financial support from the +“Deutsches Zentrum für Luft- und Raumfahrt e.V.†(DLR) through the +project Universal Bayesian Imaging Kit (UBIK, Förderkennzeichen +50OO2103). Lukas Scheel-Platz acknowledges funding from the European +Research Council (ERC) under the European Union’s Horizon Europe +research and innovation programme under grant agreement No 101041936 +(EchoLux). + +Bibiography +----------- + + +.. bibliography:: paper.bib + :style: plain + + diff --git a/src/re/evidence_lower_bound.py b/src/re/evidence_lower_bound.py index e9cd6de8c8bc81c7faecf5055fe73f66c97ef309..20c7bf11d33662dacf89bc6d4123041b7c9b6fcc 100644 --- a/src/re/evidence_lower_bound.py +++ b/src/re/evidence_lower_bound.py @@ -240,9 +240,8 @@ def estimate_evidence_lower_bound( For further details we refer to: - Analytic geoVI parametrization: P. Frank et al., Geometric Variational - Inference <https://arxiv.org/pdf/2105.10470.pdf> (Sec. 5.1) + Inference <https://arxiv.org/pdf/2105.10470.pdf> (Sec. 5.1) - Conceptualization: A. Kostić et al. (manuscript in preparation). - """ if not isinstance(samples, Samples): raise TypeError("samples attribute should be of type `Samples`.")