Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
Commits on Source (55): showing 1139 additions and 95 deletions.
......@@ -16,7 +16,7 @@ RUN pip3 install --break-system-packages \
# Testing dependencies
pytest pytest-cov pytest-xdist \
# Documentation build dependencies
jupyter nbconvert jupytext sphinx pydata-sphinx-theme myst-parser
jupyter nbconvert jupytext sphinx pydata-sphinx-theme myst-parser sphinxcontrib-bibtex
# Create user (openmpi does not like to be run as root)
RUN useradd -ms /bin/bash runner
......
......@@ -177,7 +177,7 @@ To build the documentation locally, run:
```
sudo apt-get install dvipng jupyter-nbconvert texlive-latex-base texlive-latex-extra
pip install --user sphinx jupytext pydata-sphinx-theme myst-parser
pip install --user sphinx jupytext pydata-sphinx-theme myst-parser sphinxcontrib-bibtex
cd <nifty_directory>
bash docs/generate.sh
```
......
......@@ -2,9 +2,9 @@
set -e
FOLDER=docs/source/user/
FOLDER=docs/source/user
for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/old_nifty_getting_started_4_CorrelatedFields ${FOLDER}/old_nifty_custom_nonlinearities; do
for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/old_nifty_getting_started_4_CorrelatedFields ${FOLDER}/old_nifty_custom_nonlinearities ${FOLDER}/niftyre_getting_started_4_CorrelatedFields; do
if [ ! -f "${FILE}.md" ] || [ "${FILE}.ipynb" -nt "${FILE}.md" ]; then
jupytext --to ipynb "${FILE}.py"
jupyter-nbconvert --to markdown --execute --ExecutePreprocessor.timeout=None "${FILE}.ipynb"
......@@ -12,5 +12,5 @@ for FILE in ${FOLDER}/0_intro ${FOLDER}/old_nifty_getting_started_0 ${FOLDER}/ol
done
EXCLUDE="nifty8/logger.py"
sphinx-apidoc -e -o docs/source/mod nifty8 ${EXCLUDE}
sphinx-apidoc -e -d 1 -o docs/source/mod nifty8 ${EXCLUDE}
sphinx-build -b html docs/source/ docs/build/
......@@ -4,11 +4,14 @@ needs_sphinx = '3.2.0'
extensions = [
'sphinx.ext.napoleon', # Support for NumPy and Google style docstrings
'sphinx.ext.imgmath', # Render math as images
'sphinx.ext.mathjax', # Render math with MathJax
'sphinx.ext.viewcode', # Add links to highlighted source code
'sphinx.ext.intersphinx', # Links to other sphinx docs (mostly numpy)
'myst_parser', # Parse markdown
]
'sphinxcontrib.bibtex',
]
bibtex_bibfiles = ['user/paper.bib']
master_doc = 'index'
myst_enable_extensions = [
......@@ -35,6 +38,7 @@ napoleon_use_admonition_for_references = True
napoleon_include_special_with_doc = True
imgmath_embed = True
numfig = True
project = u'NIFTy8'
copyright = u'2013-2022, Max-Planck-Society'
......
Citing NIFTy and the Magic Inside
=================================
Citing NIFTy
============
To cite the software library NIFTy, please use the following references:
......
......@@ -2,23 +2,28 @@
NIFTy user guide
================
This guide is an overview and explains the main idea behind NIFTy.
This guide is an overview and explains the main conceptual idea behind NIFTy (Numerical Information Field Theory).
More details on the API can be found at the `API reference <../mod/nifty8.html>`_.
In the following, an online version of the latest publication describing NIFTy is provided.
The publication gives a quick overview of the core concepts powering NIFTy and showcases a simple example:
The publication below gives a quick overview of the core concepts powering NIFTy and showcases a simple example.
.. toctree::
:maxdepth: 1
paper
The demonstration scripts that are part of NIFTy usually provide a good foundation for new projects:
.. toctree::
:maxdepth: 1
0_intro
The following page provides a demonstration of the correlated field model in NIFTy.re:
.. toctree::
:maxdepth: 1
niftyre_getting_started_4_CorrelatedFields
In-depth discussions of important concepts in NIFTy and details on how to cite NIFTy are provided on the following pages:
.. toctree::
......@@ -28,15 +33,11 @@ In-depth discussion of important concepts in NIFTy and details on how to cite NI
approximate_inference
citations
In the following are guides to the old NIFTy.
In the following are guides to the NumPy-based NIFTy.
They serve as a reference but are likely no longer relevant for new projects.
.. toctree::
:maxdepth: 1
old_nifty
old_nifty_volume
old_nifty_design_principles
old_nifty_custom_nonlinearities
old_nifty_getting_started_0
old_nifty_getting_started_4_CorrelatedFields
#!/usr/bin/env python
# coding: utf-8
# # Showcasing the Correlated Field Model
#
# The field model works roughly like this:
#
# ``f = HT( A * zero_mode * xi ) + offset``
#
# The correlated field is constructed using:
#
# \begin{equation*}
# cf = \verb|offset_mean| + \left( \bigotimes_{i} \frac{1}{V_i} HT_i \right) \left( \verb|zero_mode| \cdot \left( \bigotimes_{i} A_i(k) \right) \cdot \xi \right)
# \end{equation*}
#
# where the outer product $\bigotimes_{i}$ is taken over all subdomains, $V_i$ is the volume of each sub-space, $HT_i$ is the harmonic transform over each sub-space, and $\xi$ is the Gaussian excitation field.
#
# `A` is a spectral power field constructed from power spectra that are defined on the subdomains of the target domain. It is scaled by a zero-mode operator and then pointwise multiplied by a Gaussian excitation field, yielding a representation of the field in harmonic space. This is then transformed into the target real space, and an offset is added.
#
# The power spectra from which `A` is constructed are, in turn, modelled as the sum of a power-law component and an integrated Wiener process, whose amplitude and roughness can be set.
# ## Preliminaries
# +
# Imports; initialize dimensions and the random seed
# %matplotlib inline
import jax
import nifty8.re as jft
import matplotlib.pyplot as plt
from jax import numpy as jnp
from typing import Tuple
import numpy as np
import nifty8 as ift
jax.config.update("jax_enable_x64", True)
plt.rcParams["figure.dpi"] = 300
npix = 256
seed = 42
distances = 1
key = jax.random.PRNGKey(seed)
k_lengths = jnp.arange(0, npix) * (1 / npix)
totvol = jnp.prod(jnp.array(npix) * jnp.array(distances))
realisations = 5
# ## The Moment-Matched Log-Normal Distribution
#
# Many properties of the correlated field are modelled as being log-normally distributed.
#
# The distribution models are parametrized via their means and standard deviations (first and second position in the tuple, respectively).
#
# To get a feeling for how the ratio of the `mean` and `stddev` parameters influences the distribution shape, here are a few example histograms (note the differing x-axis ranges!):
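#
# For reference, the moment matching itself is simple: for a log-normal variable with target mean $m$ and standard deviation $s$, the parameters of the underlying Gaussian are $\sigma^2 = \log(1 + s^2/m^2)$ and $\mu = \log(m) - \sigma^2/2$. A minimal sketch in plain NumPy (independent of the `jft.LogNormalPrior` internals):

def lognormal_moments(m, s):
    # Gaussian parameters (mu, sigma) such that exp(N(mu, sigma^2))
    # has mean m and standard deviation s.
    sigma2 = np.log1p((s / m) ** 2)
    return np.log(m) - 0.5 * sigma2, np.sqrt(sigma2)

mm, ss = lognormal_moments(1.0, 0.5)
check = np.exp(np.random.default_rng(0).normal(mm, ss, 100_000))
print(check.mean(), check.std())  # should be close to 1.0 and 0.5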
# +
fig = plt.figure(figsize=(13, 3.5))
mean = 1.0
sigmas = [1.0, 0.5, 0.1]

for i in range(3):
    op = jft.LogNormalPrior(mean=mean, std=sigmas[i], name="foo")
    op_samples = np.zeros(10000)
    for j in range(10000):
        key, signalk = jax.random.split(key, num=2)
        s = jft.random_like(signalk, op.domain)
        op_samples[j] = op(s)
    ax = fig.add_subplot(1, 3, i + 1)
    ax.hist(op_samples, bins=50)
    ax.set_title(f"mean = {mean}, sigma = {sigmas[i]}")
    ax.set_xlabel("x")
    del op_samples
plt.show()
# +
# Making the correlated field in NIFTy.re
def fieldmaker(npix: Tuple, distances: Tuple, matern, **args):
    cf = jft.CorrelatedFieldMaker("")
    cf.set_amplitude_total_offset(
        offset_mean=args["offset_mean"], offset_std=args["offset_std"]
    )
    args.pop("offset_mean")
    args.pop("offset_std")
    # The kwarg `non_parametric_kind` accepts two choices, "power" and
    # "amplitude"; NIFTy.re's default is "amplitude".
    if matern:
        cf.add_fluctuations_matern(
            npix,
            distances,
            non_parametric_kind="power",
            renormalize_amplitude=False,
            **args,
        )
    else:
        cf.add_fluctuations(npix, distances, non_parametric_kind="power", **args)
    cf_model = cf.finalize()
    return cf_model, cf.power_spectrum

def vary_parameter(parameter, values, matern, **args):
    global key
    for i, j in enumerate(values):
        syn_data = np.zeros(shape=(npix, realisations))
        syn_pow = np.zeros(shape=(int(npix / 2 + 1), realisations))
        args[parameter] = j
        fig = plt.figure(tight_layout=True, figsize=(10, 3))
        fig.suptitle(f"{parameter} = {j}")
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.set_title("Field Realizations")
        ax1.set_ylim(-4.0, 4.0)
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.set_xscale("log")
        ax2.set_yscale("log")
        ax2.set_title("Power Spectra")
        # Plot several realisations of the field and its power spectrum.
        for k in range(realisations):
            cf_model, pow_cf = fieldmaker(npix, distances, matern, **args)
            key, signalk = jax.random.split(key, num=2)
            syn_signal = jft.random_like(signalk, cf_model.domain)
            syn_data[:, k] = cf_model(syn_signal)
            syn_pow[:, k] = pow_cf(syn_signal)
            ax1.plot(k_lengths, syn_data[:, k], linewidth=1)
            if not matern:
                ax2.plot(
                    np.arange(len(k_lengths) / 2 + 1),
                    np.sqrt(syn_pow[:, k] / totvol ** 2),
                    linewidth=1,
                )
                ax2.set_ylim(1e-6, 2.0)
            else:
                ax2.plot(np.arange(len(k_lengths) / 2 + 1), syn_pow[:, k], linewidth=1)
                ax2.set_ylim(1e-1, 1e2)
# -
# ### The Amplitude Spectrum in NIFTy.re
#
# The correlation operator $A$, which is used to transform a to-be-inferred signal $s$ to standardized $\xi$ coordinates, is given by
# $ s = A \xi $
#
# and $A$ is defined as (see below for a derivation)
#
# $$ A \mathrel{\mathop{:}}= F^{-1}\hat{S}^{1/2} $$
#
# where $F^{-1}$ is the inverse Fourier transform and $\hat{S}$ is the diagonalized correlation structure in harmonic space.
#
# The norm is defined as:
#
# \begin{equation*}
# \text{norm} = \sqrt{\frac{\int \textnormal{d}k p(k)}{V}}
# \end{equation*}
#
# The amplitude spectrum $\text{amp}(k)$ is then
# \begin{align*}
# \text{amp}(k) &= \frac{\text{fluc} \cdot \sqrt{V}}{\text{norm}} \cdot \sqrt{p(k)} \\
# \text{amp}(k) &= \frac{\text{fluc} \cdot V \cdot \sqrt{p(k)}}{\sqrt{\int \textnormal{d}k \, p(k)}}
# \end{align*}
#
# The power spectrum is just the amplitude spectrum squared:
#
# \begin{align*}
# p(k) &= \text{amp}^2(k) \\
# & = \frac{\text{fluc}^2 \cdot V^2 \cdot p(k)}{\int \textnormal{d}k p(k)} \\
# \int \textnormal{d}k p(k) & = \text{fluc}^2 \cdot V^2
# \end{align*}
#
# Hence, the fluctuations in `NIFTy.re` are given by
#
# \begin{equation*}
# \text{fluc} = \sqrt{\frac{\int \textnormal{d}k p(k)}{V^2}}
# \end{equation*}
#
# This differs from the convention in the NumPy-based NIFTy, where:
#
# \begin{equation*}
# \text{fluc} = \int \textnormal{d}k p(k)
# \end{equation*}
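#
# As a quick numerical sanity check of this convention, one can evaluate `fluc` for a toy spectrum. This is only a sketch: the power law below is made up, and only the relation $\text{fluc} = \sqrt{\int \textnormal{d}k \, p(k) / V^2}$ is taken from the derivation above.

toy_k = np.arange(1, npix // 2 + 1)
toy_p = toy_k ** -4.0  # hypothetical power spectrum, for illustration only
# Approximate the integral by a Riemann sum with dk = 1.
toy_fluc = float(np.sqrt(toy_p.sum() / totvol ** 2))
print(f"toy fluctuations amplitude: {toy_fluc:.3e}")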
#
#
# ## The Neutral Field
# +
cf_args = {
"fluctuations": (1e-3, 1e-16),
"loglogavgslope": (0.0, 1e-16),
"flexibility": (1e-3, 1e-16),
"asperity": (1e-3, 1e-16),
"prefix": "",
"offset_mean": 0.0,
"offset_std": (1e-3, 1e-16),
}
cf_model, pow_cf = fieldmaker(npix, distances, matern=False, **cf_args)
key, signalk = jax.random.split(key, num=2)
syn_signal = jft.random_like(signalk, cf_model.domain)
syn_data = cf_model(syn_signal)
syn_pow = pow_cf(syn_signal)
fig = plt.figure(tight_layout=True, figsize=(10, 3))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("Field Realizations")
ax1.set_ylim(-4.0, 4.0)
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_xscale("log")
ax2.set_yscale("log")
ax2.set_title("Power Spectra")
ax2.set_ylim(1e-6, 2.0)
ax1.plot(k_lengths, syn_data, linewidth=1)
ax2.plot(np.arange(len(k_lengths) / 2 + 1), np.sqrt(syn_pow / totvol ** 2), linewidth=1)
# -
# ## The `fluctuations` parameters of `add_fluctuations()`
#
# `fluctuations` determines the **amplitude of variations** along the field dimension for which `add_fluctuations` is called.
#
# `fluctuations[0]` sets the average amplitude of the field's fluctuations along the given dimension, and
# `fluctuations[1]` sets the width and shape of the amplitude distribution.
#
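# As a quick empirical check (a sketch reusing the `fieldmaker` helper from above; the value 0.5 is an arbitrary example), the standard deviation of a prior field realization should be of the order of the `fluctuations` mean:

cf_check, _ = fieldmaker(
    npix, distances, matern=False, **{**cf_args, "fluctuations": (0.5, 1e-16)}
)
key, check_key = jax.random.split(key, num=2)
print(float(jnp.std(cf_check(jft.random_like(check_key, cf_check.domain)))))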
# ## `fluctuations` mean
vary_parameter(
"fluctuations", [(0.05, 1e-16), (0.5, 1e-16), (1.0, 1e-16)], matern=False, **cf_args
)
cf_args["fluctuations"] = (1.0, 1e-16)
# ## `fluctuations` std
vary_parameter(
"fluctuations", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args
)
cf_args["fluctuations"] = (1.0, 1e-16)
# ## The `loglogavgslope` parameters of `add_fluctuations()`
#
# The value of `loglogavgslope` determines the __slope of the log-log-linear (power-law) component of the power spectrum.__
#
# The slope is modelled to be normally distributed.
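#
# As an illustration of the name: a pure power law $p(k) \propto k^{\text{slope}}$ is a straight line with that slope on log-log axes (arbitrary example values, not part of the model):

kk = k_lengths[1:]  # drop k = 0, which cannot be drawn on a log axis
for slope in (-6.0, -2.0, 2.0):
    plt.loglog(kk, kk ** slope, label=f"slope = {slope}")
plt.legend()
plt.show()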
# ## `loglogavgslope` mean
vary_parameter(
"loglogavgslope",
[(-6.0, 1e-16), (-2.0, 1e-16), (2.0, 1e-16)],
matern=False,
**cf_args,
)
# ## `loglogavgslope` std
vary_parameter(
"loglogavgslope", [(-2.0, 0.02), (-2.0, 0.2), (-2.0, 2.0)], matern=False, **cf_args
)
cf_args["loglogavgslope"] = (-2.0, 1e-16)
# ## The `flexibility` parameters of `add_fluctuations()`
#
# Values for `flexibility` determine the __amplitude of the integrated Wiener process component of the power spectrum__ (i.e., how strongly the power spectrum varies beyond the power law).
#
# `flexibility[0]` sets the _average_ amplitude of the integrated Wiener process component, and
# `flexibility[1]` sets how much the amplitude can vary.
# These two parameters feed into a moment-matched log-normal distribution model; see above for a demonstration of its behavior.
# ## `flexibility` mean
vary_parameter(
"flexibility", [(0.4, 1e-16), (4.0, 1e-16), (12.0, 1e-16)], matern=False, **cf_args
)
# ## `flexibility` std
vary_parameter(
"flexibility", [(4.0, 0.02), (4.0, 0.2), (4.0, 2.0)], matern=False, **cf_args
)
cf_args["flexibility"] = (4.0, 1e-16)
# ## The `asperity` parameters of `add_fluctuations()`
#
# `asperity` determines how __rough the integrated Wiener process component of the power spectrum is.__
#
# `asperity[0]` sets the average roughness, `asperity[1]` sets how much the roughness can vary.
# These two parameters feed into a moment-matched log-normal distribution model; see above for a demonstration of its behavior.
#
#
# ## `asperity` mean
vary_parameter(
"asperity", [(0.001, 1e-16), (1.0, 1e-16), (5.0, 1e-16)], matern=False, **cf_args
)
# ## `asperity` std
vary_parameter(
"asperity", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args
)
cf_args["asperity"] = (1.0, 1e-16)
# ## The `offset_mean` parameter of `CorrelatedFieldMaker()`
#
# The `offset_mean` parameter defines a global additive offset on the field realizations.
#
# If the field is used in a log-normal model, `f = field.exp()`, this acts as a global signal magnitude offset.
#
#
# Reset model to neutral
cf_args["fluctuations"] = (1e-3, 1e-16)
cf_args["flexibility"] = (1e-3, 1e-16)
cf_args["asperity"] = (1e-3, 1e-16)
cf_args["loglogavgslope"] = (1e-3, 1e-16)
vary_parameter("offset_mean", [3.0, 0.0, -2.0], matern=False, **cf_args)
vary_parameter(
"offset_std", [(1e-16, 1e-16), (0.5, 1e-16), (2.0, 1e-16)], matern=False, **cf_args
)
vary_parameter(
"offset_std", [(1.0, 0.01), (1.0, 0.1), (1.0, 1.0)], matern=False, **cf_args
)
# ## Matern Fluctuation Kernels
#
# The correlated field model also supports parametrizing the power spectra of field dimensions using Matern kernels. In the following, the effects of their parameters are demonstrated.
#
# In contrast to the fluctuations parametrization shown above, the Matern kernel parameters interact strongly. For example, the field amplitude depends not only on the amplitude scaling parameter `scale`, but on the combination of all three parameters `scale`, `cutoff`, and `loglogslope`.
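#
# For orientation, a common Matern-type amplitude spectrum has the form $\text{amp}(k) \propto \text{scale} \cdot \left(1 + (k/\text{cutoff})^2\right)^{\text{loglogslope}/4}$; this exact expression is an assumption for illustration here, not a quote of the NIFTy.re source. It is flat for $k \ll \text{cutoff}$ and approaches a power law for $k \gg \text{cutoff}$:

def matern_amp(k, scale, cutoff, loglogslope):
    # Flat below the cutoff; above it, the amplitude falls off as a
    # power law with slope loglogslope / 2.
    return scale * (1.0 + (k / cutoff) ** 2) ** (loglogslope / 4.0)

kk = k_lengths[1:]
plt.loglog(kk, matern_amp(kk, scale=0.5, cutoff=0.1, loglogslope=-2.0))
plt.show()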
# +
# Neutral model parameters yielding a quasi-constant field
cf_args_matern = {
"scale": (1e-2, 1e-16),
"cutoff": (1.0, 1e-16),
"loglogslope": (-2.0, 1e-16),
"prefix": "",
"offset_mean": 0.0,
"offset_std": (1e-3, 1e-16),
}
# -
vary_parameter(
"scale", [(0.01, 1e-16), (0.1, 1e-16), (1.0, 1e-16)], matern=True, **cf_args_matern
)
vary_parameter(
"scale", [(0.5, 0.01), (0.5, 0.1), (0.5, 0.5)], matern=True, **cf_args_matern
)
cf_args_matern["scale"] = (0.5, 1e-16)
vary_parameter(
"cutoff", [(10.0, 1.0), (10.0, 3.16), (10.0, 100.0)], matern=True, **cf_args_matern
)
# ## Theory of the Correlated Field Model
#
# - We want to infer a signal $s$ and have a (multivariate Gaussian) prior $\mathcal{G}(s,S)$.
# - Inference algorithms are sensitive to the choice of coordinates; it is more efficient to infer $\mathcal{G}(\xi, 1)$ than $\mathcal{G}(s,S)$.
# - One can find a coordinate transform (known as the amplitude transform) which changes the distribution to standardized coordinates $\xi \hookleftarrow \mathcal{G}(\xi, 1)$.
# - This transform is given by $ s = F^{-1}\hat{S}^{1/2}\xi \mathrel{\mathop{:}}= A\xi $
#
# We can now define a Correlation Operator $A$ as
# \begin{equation}
# A \mathrel{\mathop{:}}= F^{-1}\hat{S}^{1/2}
# \end{equation}
# which relates $\xi$ from the latent space to $s$ in physical space.
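#
# A minimal numerical sketch of this transform on a periodic 1D grid (plain NumPy; the white-noise draw and the power-law spectrum are made-up example choices):

rng_demo = np.random.default_rng(0)
k_demo = np.fft.fftfreq(npix, d=1.0)
p_demo = np.where(k_demo == 0.0, 0.0, np.abs(k_demo) ** -2.0)  # zero mode treated separately
xi_demo = rng_demo.standard_normal(npix)
# s = F^{-1} S-hat^{1/2} (F xi): color white noise with the sqrt of the power spectrum.
s_demo = np.fft.ifft(np.sqrt(p_demo) * np.fft.fft(xi_demo)).real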
#
# The correlation structure $S$ and the operator $A$ are related via
# \begin{equation}
# S = A A^{\dagger}
# \end{equation}
#
#
# (Here we assume that the correlation structure $S$ is statistically homogeneous, stationary, and isotropic, meaning that $S$ has no preferred location or direction a priori; by the Wiener-Khinchin theorem it can then be represented as a diagonal matrix in harmonic (Fourier) space.)
#
# Then $S$ can be written as $S = F^\dagger \hat{p}_S F$, where $\hat{p}_S$ is a diagonal matrix with the power spectrum values for each mode on the diagonal.
#
# From the relation given above, the amplitude spectrum $p_A$ is related to the power spectrum by
# $$ p_A(k) = \sqrt{p_S(k)} $$
#
# with the power spectrum defined (up to normalization) as
# $$ C_s(k) = \lim\limits_{V \to \infty} \frac{1}{V} \left< \left| \int_{V} \textnormal{d}x \, s^{x} e^{ikx} \right|^2 \right>_{s} $$
#
#
# If we do not have enough information about the correlation structure $S$ (and consequently about the amplitude operator $A$), we build a model for it and learn it from the data. Again assuming statistical homogeneity a priori, the Wiener-Khinchin theorem yields:
#
# $$ A^{kk'} = (2\pi)^{d} \delta(k - k') p_A(k) $$
#
# We can now non-parametrically model $p_A(k)$, with the only constraint being the positivity of the power spectrum. To enforce positivity, let us model $\gamma(k)$ where
# $$ p_A(k) \propto e^{\gamma(k)} $$
#
# $\gamma(k)$ is modeled with an integrated Wiener process in $l = \log|k|$ coordinates. In logarithmic coordinates, the zero mode $|k| = 0$ is infinitely far away from the other modes; hence it is treated separately.
#
# The integrated Wiener process is a stochastic process described by the equation
#
# $$ \frac{d^2 \gamma(l)}{dl^2} = \eta(l) $$
# where $\eta(l)$ is standard Gaussian distributed.
#
# One way to solve this differential equation is to split the single second-order equation into two first-order equations by substitution:
#
# $$ v = \frac{d \gamma(l)}{dl} $$
# $$ \frac{dv}{dl} = \eta(l) $$
#
# Integrating both equations with respect to $l$ results in:
#
# \begin{equation}
# \int_{l_0}^{l} v(l') \, dl' = \gamma(l) - \gamma(l_0)
# \end{equation}
#
# \begin{equation}
# v(l') - v(l_0) = \int_{l_0}^{l'} \eta(l'') \, dl''
# \end{equation}
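#
# A short numerical sketch of these equations (Euler-Maruyama integration on a hypothetical log-k grid; plain NumPy, illustration only):

rng_iwp = np.random.default_rng(42)
l_grid = np.linspace(0.0, 5.0, 200)  # l = log|k| grid (example values)
dl = np.diff(l_grid)
# Increments of v over a step dl have standard deviation sqrt(dl).
v_iwp = np.concatenate([[0.0], np.cumsum(rng_iwp.normal(size=dl.size) * np.sqrt(dl))])
gamma_iwp = np.concatenate([[0.0], np.cumsum(v_iwp[:-1] * dl)])  # integrate v over l
p_A_demo = np.exp(gamma_iwp)  # positive amplitude spectrum model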
#
# ## Mathematical intuition for the `fluctuations` parameter
#
# The two-point correlation function between two locations $x$ and $y$ is given by $S^{xy}$, which, assuming a priori statistical homogeneity, is a continuous function of the distance $x-y$ between the two points. $$ S^{xy}= C_s(x-y) \mathrel{\mathop{:}}= C_s(r)$$
#
# When this is the case, the correlation structure is diagonalized in harmonic space and is fully described by the power spectrum $P_s(k)$, i.e.,
#
# \begin{align*}
# S^{xy} & = (F^\dagger)^r_k P_s(k) \\
# & = \int \frac{\textnormal{d}k}{(2\pi)^d} e^{-ikr} P_s(k)
# \end{align*}
#
# Then, the auto-correlation is
#
# \begin{align*}
# S^{xx} & = \left< s^x s^{x*}\right> \\
# & = C_s(0) = \int \frac{\textnormal{d}k}{(2\pi)^d} e^{0} P_s(k) \\
# \int P_s(k) \, \textnormal{d}k & = \left< |s^x|^2 \right>
# \end{align*}
#
# This $P_s(k)$ is then rescaled to a power spectrum with fluctuations amplitude $a$:
#
# \begin{align*}
# P_s' = \frac{P_s}{\sqrt{\int \textnormal{d}k \, P_s(k)}} \, a^2
# \end{align*}
NumPy-Based NIFTy
==================
.. toctree::
:maxdepth: 1
old_nifty_volume
old_nifty_design_principles
old_nifty_custom_nonlinearities
old_nifty_getting_started_0
old_nifty_getting_started_4_CorrelatedFields
What is NIFTy?
==============
--------------
**NIFTy** [1]_ [2]_ [3]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying grids (spatial, spectral, temporal, …) and their resolutions.
Its object-oriented framework is written in Python, although it accesses libraries written in C++ and C for efficiency.
......
......@@ -15,7 +15,7 @@
adsurl = {https://ui.adsabs.harvard.edu/abs/2019A&A...627A.134A},
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@software{Arras2019NIFTy,
@misc{Arras2019NIFTy,
author = {{Arras}, Philipp and {Baltac}, Mihai and {Ensslin}, Torsten A. and {Frank}, Philipp and {Hutschenreuter}, Sebastian and {Knollmueller}, Jakob and {Leike}, Reimar and {Newrzella}, Max-Niklas and {Platz}, Lukas and {Reinecke}, Martin and {Stadler}, Julia},
title = {{NIFTy5: Numerical Information Field Theory v5}},
howpublished = {Astrophysics Source Code Library, record ascl:1903.008},
......@@ -96,6 +96,17 @@
adsurl = {https://ui.adsabs.harvard.edu/abs/2022ApJ...935..167A},
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@article{Blondel2021,
author = {Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-Lopez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
journal = {Advances in Neural Information Processing Systems},
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {5230--5242},
publisher = {Curran Associates, Inc.},
title = {Efficient and Modular Implicit Differentiation},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/228b9279ecf9bbafe582406850c57115-Paper-Conference.pdf},
volume = {35},
year = {2022}
}
@article{Bingham2019,
author = {Eli Bingham and Jonathan P. Chen and Martin Jankowiak and Fritz Obermeyer and Neeraj Pradhan and Theofanis Karaletsos and Rohit Singh and Paul A. Szerlip and Paul Horsfall and Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
......@@ -105,24 +116,13 @@
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}
@software{blackjax2020,
author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi},
@misc{blackjax2020,
author = {Cabezas, Alberto and Lao, Junpeng and Louf, R\'emi},
title = {{B}lackjax: A sampling library for {JAX}},
url = {http://github.com/blackjax-devs/blackjax},
version = {v1.1.0},
year = {2023}
}
@article{Blondel2021,
author = {Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-Lopez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
booktitle = {Advances in Neural Information Processing Systems},
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {5230--5242},
publisher = {Curran Associates, Inc.},
title = {Efficient and Modular Implicit Differentiation},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/228b9279ecf9bbafe582406850c57115-Paper-Conference.pdf},
volume = {35},
year = {2022}
}
@article{Carpenter2017,
title = {Stan: A Probabilistic Programming Language},
volume = {76},
......@@ -134,13 +134,13 @@
year = {2017},
pages = {1-32}
}
@software{Deepmind2020Optax,
@misc{Deepmind2020Optax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/google-deepmind},
year = {2020}
}
@software{ducc0,
@misc{ducc0,
author = {Martin Reinecke},
title = {{DUCC}: Distinctly Useful Code Collection},
url = {https://gitlab.mpcdf.mpg.de/mtr/ducc},
......@@ -237,7 +237,7 @@
url = {https://arxiv.org/abs/1703.09710}
}
@software{ForemanMackey2024,
@misc{ForemanMackey2024,
author = {Foreman-Mackey, Daniel and Weixiang Yu and Yadav, Sachin and Becker, McCoy Reynolds and Caplar, Neven and Huppenkothen, Daniela and Killestein, Thomas and Tronsgaard, René and Rashid, Theo and Schmerler, Steve},
title = {{dfm/tinygp: The tiniest of Gaussian Process libraries}},
month = jan,
......@@ -429,7 +429,7 @@
adsurl = {https://ui.adsabs.harvard.edu/abs/2023arXiv230412350H},
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@software{Jax2018,
@misc{Jax2018,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
......@@ -452,7 +452,7 @@
year = {2019},
copyright = {arXiv.org perpetual, non-exclusive license}
}
@software{Koposov2023,
@misc{Koposov2023,
author = {Sergey Koposov and Josh Speagle and Kyle Barbary and Gregory Ashton and Ed Bennett and Johannes Buchner and Carl Scheffler and Ben Cook and Colm Talbot and James Guillochon and Patricio Cubillos and Andrés Asensio Ramos and Ben Johnson and Dustin Lang and Ilya and Matthieu Dartiailh and Alex Nitz and Andrew McCluskey and Anne Archibald},
title = {joshspeagle/dynesty: v2.1.3},
month = oct,
......@@ -768,3 +768,4 @@
journal = {Journal of Open Source Software}
}
../../../paper/paper.md
\ No newline at end of file
(This file's diff is collapsed and not shown.)
[project]
name = "nifty8"
version = "8.5.5"
version = "8.5.6"
description = "Probabilistic programming framework for signal inference algorithms that operate regardless of the underlying grids and their resolutions"
readme = "README.md"
authors = [
......
......@@ -17,6 +17,7 @@
import numpy as np
from .multi_domain import MultiDomain
from .operators.operator import Operator
from .sugar import makeOp
from .utilities import check_object_identity
......@@ -65,10 +66,16 @@ class Linearization(Operator):
return self.make_var(self._val, self._want_metric)
def prepend_jac(self, jac):
if jac.isIdentity():
return self
if self._metric is None:
if self._jac.isIdentity():
return self.new(self._val, jac)
return self.new(self._val, self._jac @ jac)
from .operators.sandwich_operator import SandwichOperator
metric = SandwichOperator.make(jac, self._metric)
if self._jac.isIdentity():
return self.new(self._val, jac, metric)
return self.new(self._val, self._jac @ jac, metric)
@property
......@@ -118,6 +125,8 @@ class Linearization(Operator):
return self._metric
def __getitem__(self, name):
if not isinstance(self.domain, MultiDomain):
return NotImplemented
return self.new(self._val[name], self._jac.ducktape_left(name))
def __neg__(self):
......
......@@ -12,20 +12,26 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Copyright(C) 2025 Philipp Arras
#
# Author: Philipp Arras, Philipp Frank
import os
import pathlib
import pickle
import re
import time
from warnings import warn
import numpy as np
from .. import utilities
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.operator import Operator
from ..utilities import get_MPI_params_from_comm, shareRange
from ..utilities import (ensure_all_tasks_succeed, get_MPI_params_from_comm,
shareRange)
class SampleListBase:
......@@ -428,7 +434,7 @@ class SampleListBase:
if re.match(f"{base_file}.[0-9]+.pickle", ff)]
if len(files) == 0:
raise RuntimeError(f"No files matching `{file_name_base}.*.pickle`")
n_samples = max(list(map(lambda x: int(x.split(".")[-2]), files))) + 1
n_samples = _consecutive_length(list(map(lambda x: int(x.split(".")[-2]), files)))
ntask, rank, _ = get_MPI_params_from_comm(comm)
local_indices = range(*shareRange(n_samples, ntask, rank))
......@@ -438,6 +444,26 @@ class SampleListBase:
raise RuntimeError(f"File {ff} not found")
return files
def deep_eq(self, other):
# TODO: Potentially add __eq__ to MultiField and Field to simplify this
if not isinstance(other, SampleListBase):
return NotImplemented
if self.n_samples != other.n_samples:
return False
for xx, yy in zip(self.iterator(), other.iterator()):
if xx.domain is not yy.domain:
return False
if isinstance(xx, MultiField):
xx = list(xx.val.values())
yy = list(yy.val.values())
else:
xx, yy = [xx.val], [yy.val]
assert len(xx) == len(yy)
for xxx, yyy in zip(xx, yy):
if np.any(xxx != yyy):
return False
return True
class ResidualSampleList(SampleListBase):
def __init__(self, mean, residuals, neg, comm=None):
......@@ -508,13 +534,23 @@ class ResidualSampleList(SampleListBase):
return ResidualSampleList(mean, self._r, self._n, self.comm)
def save(self, file_name_base, overwrite=False):
for ii, isample in enumerate(self.local_indices):
obj = [self._r[ii], self._n[ii]]
fname = _sample_file_name(file_name_base, isample)
_save_to_disk(fname, obj, overwrite)
if self.MPI_master:
_save_to_disk(f"{file_name_base}.mean.pickle", self._m, overwrite)
_barrier(self.comm)
# TODO: Make this function atomic potentially unify with
# SampleList.save.
_ensure_proper_sample_list_ending(_sample_file_name(file_name_base, self.n_samples),
overwrite, self.comm)
# Save samples
with ensure_all_tasks_succeed(self.comm):
for ii, isample in enumerate(self.local_indices):
obj = [self._r[ii], self._n[ii]]
fname = _sample_file_name(file_name_base, isample)
_save_to_disk(fname, obj, overwrite)
# Save mean
with ensure_all_tasks_succeed(self.comm):
if self.MPI_master:
_save_to_disk(f"{file_name_base}.mean.pickle", self._m, overwrite)
@classmethod
def load(cls, file_name_base, comm=None):
......@@ -568,13 +604,18 @@ class SampleList(SampleListBase):
return self._s[i]
def save(self, file_name_base, overwrite=False):
nsample = self.n_samples
for isample in range(nsample):
if isample in self.local_indices:
obj = self._s[isample-self.local_indices[0]]
# TODO: Make this function atomic potentially unify with
# ResidualSampleList.save.
_ensure_proper_sample_list_ending(_sample_file_name(file_name_base, self.n_samples),
overwrite, self.comm)
# Save samples
with ensure_all_tasks_succeed(self.comm):
for ii, isample in enumerate(self.local_indices):
obj = self._s[ii]
fname = _sample_file_name(file_name_base, isample)
_save_to_disk(fname, obj, overwrite=True)
_barrier(self.comm)
_save_to_disk(fname, obj, overwrite)
@classmethod
def load(cls, file_name_base, comm=None):
......@@ -647,6 +688,8 @@ def _load_from_disk(file_name):
def _save_to_disk(file_name, obj, overwrite=False):
if not overwrite and os.path.isfile(file_name):
raise RuntimeError(f"{file_name} already exists")
if overwrite and os.path.isfile(file_name):
os.remove(file_name)
with open(file_name, "wb") as f:
......@@ -682,3 +725,31 @@ def _compute_local_indices(n_local, comm):
n_locals = comm.allgather(n_local)
start = sum(n_locals[:comm.Get_rank()])
return range(start, start + n_local)
def _consecutive_length(lst):
if 0 not in lst:
raise ValueError("List does not contain 0, the starting element")
res = 0
while True:
if res + 1 not in lst:
return res + 1
res += 1
def _ensure_proper_sample_list_ending(fname, overwrite, comm):
with ensure_all_tasks_succeed(comm):
MPI_master = utilities.get_MPI_params_from_comm(comm)[2]
if MPI_master:
if overwrite:
# Remove potential "next sample"
# TODO: this might leave dangling samples that are either
# ignored or overwritten later. So it is fine for now. Will
# be deleted once a proper atomic approach is implemented.
pathlib.Path(fname).unlink(missing_ok=True)
else:
# Make sure that "next" sample does not exist
if os.path.isfile(fname):
raise RuntimeError(f"{fname} already exists. You may "
"want to remove it or specify overwrite=True")
......@@ -42,9 +42,16 @@ class ChainOperator(LinearOperator):
@staticmethod
def simplify(ops):
if len(ops) == 1:
return ops
# verify domains
for i in range(len(ops) - 1):
utilities.check_object_identity(ops[i + 1].target, ops[i].domain)
if len(ops)==2:
if ops[0].isIdentity():
return [ops[1]]
if ops[1].isIdentity():
return [ops[0]]
# unpack ChainOperators
opsnew = []
for op in ops:
......
......@@ -11,13 +11,10 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2013-2025 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import partial
from operator import mul
import numpy as np
from .. import utilities
......@@ -25,6 +22,19 @@ from ..domain_tuple import DomainTuple
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
# TODO: Eventually enforce somewhat modern ducc version (>=0.37.0) as nifty
# dependency and remove the try statement below
try:
from ducc0.misc.experimental import mul_conj, div_conj
except ImportError:
def mul_conj(a, b, out=None):
out = a*b.conj()
return out
def div_conj(a, b, out=None):
out = a/b.conj()
return out
class DiagonalOperator(EndomorphicOperator):
"""Represents a :class:`LinearOperator` which is diagonal.
......@@ -60,11 +70,17 @@ class DiagonalOperator(EndomorphicOperator):
This shortcoming will hopefully be fixed in the future.
"""
def __init__(self, diagonal, domain=None, spaces=None, sampling_dtype=None):
def __init__(self, diagonal, domain=None, spaces=None, sampling_dtype=None,
_trafo=0):
# MR: _trafo is more or less deliberately undocumented, since it is not supposed
# to be necessary for "end users". It describes the type of transform for which
# the diagonal can be used without modification
# (0:TIMES, 1:ADJOINT, 2:INVERSE, 3:ADJOINT_INVERSE)
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
utilities.check_dtype_or_none(sampling_dtype)
self._dtype = sampling_dtype
self._trafo = _trafo
if domain is None:
self._domain = diagonal.domain
else:
......@@ -78,7 +94,8 @@ class DiagonalOperator(EndomorphicOperator):
raise ValueError("spaces and domain must have the same length")
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("Mismatch:\n{diagonal.domain[i]}\n{self._domain[j]}")
raise ValueError(f"Mismatch between:\n{diagonal.domain[i]}\n"
f"and:\n{self._domain[j]}")
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
......@@ -100,75 +117,97 @@ class DiagonalOperator(EndomorphicOperator):
self._complex = utilities.iscomplextype(self._ldiag.dtype)
self._capability = self._all_ops
if not self._complex:
self._diagmin = self._ldiag.min()
self._diagmin_cache = None
@property
def _diagmin(self):
if self._complex:
raise RuntimeError("complex DiagonalOperator does not have _diagmin")
if self._diagmin_cache is None:
self._diagmin_cache = self._ldiag.min()
return self._diagmin_cache
def _from_ldiag(self, spc, ldiag, sampling_dtype):
def _from_ldiag(self, spc, ldiag, sampling_dtype, trafo):
res = DiagonalOperator.__new__(DiagonalOperator)
res._dtype = sampling_dtype
res._trafo = trafo
res._domain = self._domain
if self._spaces is None or spc is None:
res._spaces = None
else:
res._spaces = tuple(set(self._spaces) | set(spc))
res._ldiag = np.array(ldiag)
utilities.myassert(isinstance(ldiag, np.ndarray))
res._ldiag = ldiag
res._fill_rest()
return res
def _get_actual_diag(self):
if self._trafo == 0:
return self._ldiag
if self._trafo == 1:
return np.conj(self._ldiag) if self._complex else self._ldiag
if self._trafo == 2:
return 1./self._ldiag
if self._trafo == 3:
return np.conj(1./self._ldiag) if self._complex else 1./self._ldiag
def _scale(self, fct):
if not np.isscalar(fct):
raise TypeError("scalar value required")
return self._from_ldiag((), self._ldiag*fct, self._dtype)
return self._from_ldiag((), self._get_actual_diag()*fct, self._dtype, 0)
def _add(self, sum_):
if not np.isscalar(sum_):
raise TypeError("scalar value required")
return self._from_ldiag((), self._ldiag+sum_, self._dtype)
return self._from_ldiag((), self._get_actual_diag()+sum_, self._dtype, 0)
def _combine_prod(self, op):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
dtype = self._dtype if self._dtype == op._dtype else None
return self._from_ldiag(op._spaces, self._ldiag*op._ldiag, dtype)
return self._from_ldiag(op._spaces, self._get_actual_diag()*op._get_actual_diag(),
dtype, 0)
def _combine_sum(self, op, selfneg, opneg):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
tdiag = (self._ldiag * (-1 if selfneg else 1) +
op._ldiag * (-1 if opneg else 1))
tdiag = (self._get_actual_diag() * (-1 if selfneg else 1) +
op._get_actual_diag() * (-1 if opneg else 1))
dtype = self._dtype if self._dtype == op._dtype else None
return self._from_ldiag(op._spaces, tdiag, dtype)
return self._from_ldiag(op._spaces, tdiag, dtype, 0)
def apply(self, x, mode):
self._check_input(x, mode)
# shortcut for most common cases
if mode == 1 or (not self._complex and mode == 2):
# To save both time and memory, we remap the `mode` (via `self._trafo`)
# and do not compute and store a new `self._ldiag`s for adjoint, inverse
# or adjoint-inverse DiagonalOperators.
trafo = self._ilog[mode] ^ self._trafo
if trafo == 0: # straight application
return Field(x.domain, x.val*self._ldiag)
xdiag = self._ldiag
if self._complex and (mode & 10): # adjoint or inverse adjoint
xdiag = xdiag.conj()
if trafo == 1: # adjoint
return Field(x.domain, mul_conj(x.val, self._ldiag)
if self._complex else x.val*self._ldiag)
if trafo == 2: # inverse
return Field(x.domain, x.val/self._ldiag)
if mode & 3:
return Field(x.domain, x.val*xdiag)
return Field(x.domain, x.val/xdiag)
# adjoint inverse
return Field(x.domain, div_conj(x.val, self._ldiag)
if self._complex else x.val/self._ldiag)
def _flip_modes(self, trafo):
if trafo == self.ADJOINT_BIT and not self._complex: # shortcut
return self
xdiag = self._ldiag
if self._complex and (trafo & self.ADJOINT_BIT):
xdiag = xdiag.conj()
if trafo & self.INVERSE_BIT:
# dividing by zero is OK here, we can deal with infinities
with np.errstate(divide='ignore'):
xdiag = 1./xdiag
return self._from_ldiag((), xdiag, self._dtype)
return self._from_ldiag((), self._ldiag, self._dtype, self._trafo ^ trafo)
def process_sample(self, samp, from_inverse):
from_inverse2 = from_inverse ^ (self._trafo >= 2)
# `from_inverse2` captures if the inverse of `self._ldiag` needs to be
# taken or not (can happen for nontrivial `self._trafo`).
if (self._complex or (self._diagmin < 0.) or
(self._diagmin == 0. and from_inverse)):
(self._diagmin == 0. and from_inverse2)):
raise ValueError("operator not positive definite")
if from_inverse:
if from_inverse2:
res = samp.val/np.sqrt(self._ldiag)
else:
res = samp.val*np.sqrt(self._ldiag)
......@@ -184,9 +223,9 @@ class DiagonalOperator(EndomorphicOperator):
return self.process_sample(res, from_inverse)
def get_sqrt(self):
if np.iscomplexobj(self._ldiag) or (self._ldiag < 0).any():
if self._complex or self._diagmin < 0.:
raise ValueError("get_sqrt() works only for positive definite operators.")
return self._from_ldiag((), np.sqrt(self._ldiag), self._dtype)
return self._from_ldiag((), np.sqrt(self._ldiag), self._dtype, self._trafo)
def __repr__(self):
from ..multi_domain import MultiDomain
......
......@@ -573,7 +573,7 @@ class GaussianEnergy(LikelihoodEnergyOperator):
residual = x if self._data is None else x - self._data
res = self._op(residual).real
if x.want_metric:
return res.add_metric(self.get_metric_at(x.val))
return res.add_metric(self._icov)
return res
def get_transformation(self):
......
......@@ -97,12 +97,16 @@ class LinearOperator(Operator):
return self._flip_modes(self.ADJOINT_BIT)
def __matmul__(self, other):
if other.isIdentity():
return self
if isinstance(other, LinearOperator):
from .chain_operator import ChainOperator
return ChainOperator.make([self, other])
return Operator.__matmul__(self, other)
def __rmatmul__(self, other):
if other.isIdentity():
return self
if isinstance(other, LinearOperator):
from .chain_operator import ChainOperator
return ChainOperator.make([other, self])
......@@ -169,6 +173,8 @@ class LinearOperator(Operator):
def __call__(self, x):
"""Same as :meth:`times`"""
if self.isIdentity():
return x
if x.jac is not None:
return x.new(self(x._val), self).prepend_jac(x.jac)
if x.val is not None:
......
......@@ -114,6 +114,9 @@ class Operator(metaclass=NiftyMeta):
"""
return None
def isIdentity(self): # Will be overloaded in ScalingOperator
return False
def scale(self, factor):
if not isinstance(factor, numbers.Number):
raise TypeError(".scale() takes a number as input")
......@@ -165,6 +168,10 @@ class Operator(metaclass=NiftyMeta):
if isinstance(x, LikelihoodEnergyOperator):
return NotImplemented
if x.target is self.domain:
if x.isIdentity():
return self
if self.isIdentity():
return x
return _OpChain.make((self, x))
return self.partial_insert(x)
......@@ -176,6 +183,10 @@ class Operator(metaclass=NiftyMeta):
if isinstance(x, LikelihoodEnergyOperator):
return NotImplemented
if x.domain is self.target:
if x.isIdentity():
return self
if self.isIdentity():
return x
return _OpChain.make((x, self))
return x.partial_insert(self)
......
......@@ -66,6 +66,9 @@ class ScalingOperator(EndomorphicOperator):
check_dtype_or_none(sampling_dtype, self._domain)
self._dtype = sampling_dtype
def isIdentity(self):
return self._factor == 1
def apply(self, x, mode):
from ..sugar import full
......