diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index bdec5dc6bbc22f289909ca5695b73c46ddbaace8..60c223c9c14597443b2b430c750752ff1b6a8e9c 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -142,18 +142,26 @@ run_getting_started_mf:
       - 'getting_started_mf_results'
       - '*.png'
 
+run_getting_started_nifty2jax:
+  stage: demo_runs
+  script:
+    - python3 demos/getting_started_6_nifty2jax.py
+  artifacts:
+    paths:
+      - '*.png'
+
 run_getting_density:
   stage: demo_runs
   script:
-    - python3 demos/getting_started_density.py
+    - python3 demos/more/density_estimation.py
   artifacts:
     paths:
       - '*.png'
 
-run_getting_started_model_comparison:
+run_model_comparison:
   stage: demo_runs
   script:
-    - python3 demos/getting_started_model_comparison.py
+    - python3 demos/more/model_comparison.py
   artifacts:
     paths:
       - '*.png'
@@ -161,7 +169,7 @@ run_getting_started_model_comparison:
 run_bernoulli:
   stage: demo_runs
   script:
-    - python3 demos/bernoulli_demo.py
+    - python3 demos/more/bernoulli_map.py
   artifacts:
     paths:
       - '*.png'
@@ -169,7 +177,7 @@ run_bernoulli:
 run_curve_fitting:
   stage: demo_runs
   script:
-    - python3 demos/polynomial_fit.py
+    - python3 demos/more/polynomial_fit.py
   artifacts:
     paths:
       - '*.png'
@@ -177,9 +185,65 @@ run_curve_fitting:
 run_visual_vi:
   stage: demo_runs
   script:
-    - python3 demos/variational_inference_visualized.py
+    - python3 demos/more/variational_inference_visualized.py
 
 run_meanfield:
   stage: demo_runs
   script:
-    - python3 demos/parametric_variational_inference.py
+    - python3 demos/more/parametric_variational_inference.py
+
+run_demo_categorical_L1:
+  stage: demo_runs
+  script:
+    - python3 demos/re/categorical_L1.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_cf_w_known_spectrum:
+  stage: demo_runs
+  script:
+    - python3 demos/re/correlated_field_w_known_spectrum.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_cf_w_unknown_spectrum:
+  stage: demo_runs
+  script:
+    - python3 demos/re/correlated_field_w_unknown_spectrum.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_cf_w_unknown_factorizing_spectra:
+  stage: demo_runs
+  script:
+    - python3 demos/re/correlated_field_w_unknown_factorizing_spectra.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_nifty_to_jifty:
+  stage: demo_runs
+  script:
+    - python3 demos/re/nifty_to_jifty.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_banana:
+  stage: demo_runs
+  script:
+    - python3 demos/re/banana.py
+  artifacts:
+    paths:
+      - '*.png'
+
+run_demo_banana_w_reg:
+  stage: demo_runs
+  script:
+    - python3 demos/re/banana_w_reg.py
+  artifacts:
+    paths:
+      - '*.png'
diff --git a/demos/getting_started_6_nifty2jax.py b/demos/getting_started_6_nifty2jax.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a1a6b2e5795ee7c90b41dbc9bff9627daf7c49a
--- /dev/null
+++ b/demos/getting_started_6_nifty2jax.py
@@ -0,0 +1,348 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+# %% [markdown]
+# ## What Is This All About?
+#
+# * A short introduction to porting code from NIFTy to JAX + NIFTy (jifty)
+#   * How to get the JAX expression for a NIFTy operator
+#   * How to minimize in jifty
+# * Benchmark NIFTy vs jifty
+
+# %%
+from collections import namedtuple
+from functools import partial
+import sys
+
+from jax import jit, value_and_grad
+from jax import random
+from jax import numpy as jnp
+from jax.config import config as jax_config
+from jax.tree_util import tree_map
+import matplotlib.pyplot as plt
+import numpy as np
+
+import nifty8 as ift
+import nifty8.re as jft
+
+jax_config.update("jax_enable_x64", True)
+# jax_config.update('jax_log_compiles', True)
+
+# %%
+filename = "getting_started_nifty2jax{}.png"
+
+position_space = ift.RGSpace([512, 512])
+cfm_kwargs = {
+    'offset_mean': -2.,
+    'offset_std': (1e-5, 1e-6),
+    'fluctuations': (2., 0.2),  # Amplitude of field fluctuations
+    'loglogavgslope': (-4., 1),  # Exponent of power law power spectrum
+    # Amplitude of integrated Wiener process on top of power law power spectrum
+    'flexibility': (8e-1, 1e-1),
+    'asperity': (3e-1, 1e-3)  # Ragged-ness of integrated Wiener process
+}
+
+correlated_field_nft = ift.SimpleCorrelatedField(position_space, **cfm_kwargs)
+pow_spec_nft = correlated_field_nft.power_spectrum
+
+signal_nft = correlated_field_nft.exp()
+response_nft = ift.GeometryRemover(signal_nft.target)
+signal_response_nft = response_nft(signal_nft)
+
+# %% [markdown]
+# ## From NIFTy to JAX + NIFTy
+#
+# By now, we have built a beautiful and fairly complicated forward model.
+# However, instead of evaluating it with vanilla NumPy (i.e. with plain
+# NIFTy), we want to compile the forward pass with JAX.
+
+# Note that JAX + NIFTy has no concept of domains. It still needs to know
+# how large the parameter space is, though. This can be provided either
+# via an initializer or via a pytree containing the shapes and dtypes.
+# Thus, in addition to extracting the JAX call, we also need to extract
+# the parameter space on which this call should act.
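+
+# %%
+# A minimal, hypothetical sketch of the pytree option: a pytree of
+# shape-and-dtype objects is enough to describe a parameter space. We use
+# JAX's `ShapeDtypeStruct` purely for illustration here; jifty's own
+# shape/dtype container may differ.
+from jax import ShapeDtypeStruct
+
+toy_parameter_space = {
+    "xi": ShapeDtypeStruct(shape=position_space.shape, dtype=jnp.float64)
+}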
+
+# %%
+pow_spec = pow_spec_nft.jax_expr
+signal = signal_nft.jax_expr
+# Convenience method to get JAX expression and domain
+signal_response = ift.nifty2jax.convert(signal_response_nft, float)
+
+noise_cov = 0.5**2
+
+# %%
+key = random.PRNGKey(42)
+
+key, sk = random.split(key)
+synth_pos = jft.random_like(sk, signal_response)
+data = synth_signal_response = signal_response(synth_pos)
+key, sk = random.split(key)  # Fresh key; do not reuse the signal's key
+data += jnp.sqrt(noise_cov) * random.normal(sk, shape=data.shape)
+
+fig, axs = plt.subplots(1, 2, figsize=(8, 4))
+im = axs.flat[0].imshow(synth_signal_response)
+fig.colorbar(im, ax=axs.flat[0])
+im = axs.flat[1].imshow(data)
+fig.colorbar(im, ax=axs.flat[1])
+fig.tight_layout()
+plt.show()
+
+# %%
+lh = jft.Gaussian(data, noise_cov_inv=lambda x: x / noise_cov) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=lh).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos = pos_init = 1e-2 * jft.random_like(subkey, signal_response)
+
+# %% [markdown]
+# Let's do a simple MGVI minimization. Note that while this might look
+# very similar to plain NIFTy, the convergence criteria and various
+# implementation details are quite different. Thus, timing the
+# minimization and comparing it to NIFTy would most probably yield very
+# skewed results. For benchmarking purposes, it is best to compare only a
+# single value-and-gradient call in both implementations.
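+
+# %%
+# For reference, a single value-and-gradient call reads as follows. This
+# is a plain-JAX sketch; the MGVI loop below uses the sample-averaged
+# variants `ham_vg` and `ham_metric` instead.
+energy, gradient = value_and_grad(ham)(pos)
+print(f"Initial energy {energy:2.4e}", file=sys.stderr)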
+
+# %%
+n_mgvi_iterations = 10
+n_samples = 2
+absdelta = 0.1
+n_newton_iterations = 15
+
+# Minimize the potential
+key, *sk = random.split(key, 1 + n_mgvi_iterations)
+for i, subkey in enumerate(sk):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    mg_samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_name=None,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=mg_samples),
+            "hessp": partial(ham_metric, primals_samples=mg_samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+# The minimization is done; let's have a look at the result.
+fig, axs = plt.subplots(1, 3, figsize=(12, 4))
+im = axs.flat[0].imshow(synth_signal_response)
+fig.colorbar(im, ax=axs.flat[0])
+im = axs.flat[1].imshow(data)
+fig.colorbar(im, ax=axs.flat[1])
+sr_pm = mg_samples.at(pos).mean(signal_response)
+im = axs.flat[2].imshow(sr_pm)
+fig.colorbar(im, ax=axs.flat[2])
+fig.tight_layout()
+plt.show()
+
+# %% [markdown]
+# Awesome! We have now seen how a model can be translated to JAX. By
+# doing so, we were able to use convenient transformations like `jit` and
+# `value_and_grad` from JAX. Thus, we can start using higher-order
+# derivatives and other useful JAX features like `vmap` and `pmap`. Last
+# but certainly not least, we can now also let our code run on the GPU.
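+
+# %%
+# A quick sketch of `vmap` in action, assuming (as everywhere in this
+# demo) that jifty fields behave as JAX pytrees: evaluate the Hamiltonian
+# for a small batch of positions in a single vectorized call.
+from jax import vmap
+
+pos_batch = tree_map(lambda x: jnp.stack((x, 1.1 * x)), pos)
+print(f"Batched energies: {vmap(ham)(pos_batch)}", file=sys.stderr)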
+
+# %% [markdown]
+# ## Performance
+#
+# The driving force behind all of this is of course speed! So let's validate
+# that translating the model to JAX actually is faster.
+
+# %%
+Timed = namedtuple("Timed", ("time", "number"), rename=True)
+
+
+def timeit(stmt, setup=lambda: None, number=None):
+    import timeit
+
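+    # If no repetition count is given, let the stdlib Timer pick a
+    # sensible one, then report the average wall-clock time per call.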
+    if number is None:
+        number, _ = timeit.Timer(stmt).autorange()
+
+    setup()
+    t = timeit.timeit(stmt, number=number) / number
+    return Timed(time=t, number=number)
+
+
+r = jft.random_like(random.PRNGKey(54), signal_response)
+
+r_nft = ift.makeField(signal_response_nft.domain, r.val)
+data_nft = ift.makeField(signal_response_nft.target, data)
+lh_nft = ift.GaussianEnergy(
+    data_nft,
+    inverse_covariance=ift.ScalingOperator(data_nft.domain, 1. / noise_cov)
+) @ signal_response_nft
+ham_nft = ift.StandardHamiltonian(lh_nft)
+
+_ = ham(r)  # Warm-Up
+t = timeit(lambda: ham(r).block_until_ready())
+t_nft = timeit(lambda: ham_nft(r_nft))
+
+print(f"W/  JAX :: {t}")
+print(f"W/O JAX :: {t_nft}")
+
+# %%
+# At roughly 2e+5 parameters, the FFT starts to dominate the computation,
+# and NumPy-based NIFTy is about as fast as JAX-based NIFTy. Thus, we
+# should not have expected to gain much performance for the model at hand.
+
+# So far so good, but are we really sure both versions compute the same
+# thing? To validate the JAX model, let's transfer our synthetic position
+# to plain NIFTy and run the model there again.
+
+sp = ift.makeField(signal_response_nft.domain, synth_pos.val)
+np.testing.assert_allclose(
+    signal_response_nft(sp).val, signal_response(synth_pos)
+)
+
+# %% [markdown]
+# For smaller models, or models in which the FFT does not dominate,
+# JAX-based NIFTy should always have an edge over NumPy-based NIFTy. The
+# difference in performance ranges from a few tens of percent at roughly
+# 1e+5 parameters to many orders of magnitude for very small models. For
+# example, with 65536 parameters, JAX-based NIFTy should be two to three
+# times faster.
+
+# We can show this more explicitly with a proper benchmark. In the following we
+# will instantiate models of various shapes and time the JAX version against
+# the NumPy version. Instead of testing solely a single forward pass, we will
+# compare a full evaluation of the model and its gradient.
+
+
+# %%
+def get_lognormal_model(shapes, cfm_kwargs, data_key, noise_cov=0.5**2):
+    import warnings
+
+    position_space = ift.RGSpace(shapes)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            action="ignore", category=UserWarning, message="no JAX"
+        )
+        correlated_field_nft = ift.SimpleCorrelatedField(
+            position_space, **cfm_kwargs
+        )
+        signal_nft = correlated_field_nft.exp()
+        response_nft = ift.GeometryRemover(signal_nft.target)
+        signal_response_nft = response_nft(signal_nft)
+
+    signal_response = ift.nifty2jax.convert(signal_response_nft, float)
+
+    sk_signal, sk_noise = random.split(data_key)
+    synth_pos = jft.random_like(sk_signal, signal_response)
+    data = signal_response(synth_pos)
+    data += jnp.sqrt(noise_cov) * random.normal(sk_noise, shape=data.shape)
+
+    noise_cov_inv = 1. / noise_cov
+    noise_std_inv = jnp.sqrt(noise_cov_inv)
+    lh = jft.Gaussian(
+        data,
+        noise_cov_inv=lambda x: noise_cov_inv * x,
+        noise_std_inv=lambda x: noise_std_inv * x
+    ) @ signal_response
+    ham = jft.StandardHamiltonian(likelihood=lh)
+    ham_vg = value_and_grad(ham)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            action="ignore", category=UserWarning, message="no JAX"
+        )
+        data_nft = ift.makeField(signal_response_nft.target, data)
+        noise_cov_inv_nft = ift.ScalingOperator(data_nft.domain, 1. / noise_cov)
+        lh_nft = ift.GaussianEnergy(
+            data_nft, inverse_covariance=noise_cov_inv_nft
+        ) @ signal_response_nft
+        ham_nft = ift.StandardHamiltonian(lh_nft)
+
+    def ham_vg_nft(x):
+        x = x.val if isinstance(x, jft.Field) else x
+        x = ift.makeField(ham_nft.domain, x)
+        x = ift.Linearization.make_var(x)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                action="ignore", category=UserWarning, message="no JAX"
+            )
+            res = ham_nft(x)
+        one_nft = ift.Field(ift.DomainTuple.make(()), 1.)
+        bwd = res.jac.adjoint_times(one_nft)
+        return (res.val.val, bwd.val)
+
+    aux = {
+        "synthetic_position": synth_pos,
+        "hamiltonian_nft": ham_nft,
+        "hamiltonian": ham,
+        "signal_response_nft": signal_response_nft,
+        "signal_response": signal_response,
+    }
+    return ham_vg, ham_vg_nft, aux
+
+
+get_ln_mod = partial(
+    get_lognormal_model, cfm_kwargs=cfm_kwargs, data_key=key, noise_cov=0.5**2
+)
+
+dimensions_to_test = [
+    (256, ), (512, ), (1024, ), (256**2, ), (512**2, ), (128, 128), (256, 256),
+    (512, 512), (1024, 1024), (2048, 2048)
+]
+for dims in dimensions_to_test:
+    h, h_nft, aux = get_ln_mod(dims)
+    r = aux["synthetic_position"]
+    h = jit(h)
+    _ = h(r)  # Warm-Up
+
+    np.testing.assert_allclose(h(r)[0], h_nft(r)[0])
+    ift.myassert(all(tree_map(np.allclose, h(r)[1].val, h_nft(r)[1]).values()))
+    ti = timeit(lambda: h(r)[0].block_until_ready())
+    ti_n = timeit(lambda: h_nft(r))
+
+    print(
+        f"Shape {str(dims):>16s}"
+        f" :: JAX {ti.time:4.2e}"
+        f" :: NIFTy {ti_n.time:4.2e}"
+        f" ;; ({ti.number:6d}, {ti_n.number:<6d} loops respectively)"
+    )
+
+# %% [markdown]
+# | Shape        | JAX [s]  | NIFTy [s] | Loops (JAX, NIFTy) |
+# |:-------------|:---------|:----------|-------------------:|
+# | (256,)       | 2.58e-05 | 6.96e-03  | 10000, 50          |
+# | (512,)       | 3.90e-05 | 7.14e-03  | 10000, 50          |
+# | (1024,)      | 6.33e-05 | 6.97e-03  | 5000, 50           |
+# | (65536,)     | 5.41e-03 | 1.42e-02  | 50, 20             |
+# | (262144,)    | 2.72e-02 | 4.41e-02  | 10, 5              |
+# | (128, 128)   | 5.07e-04 | 7.00e-03  | 500, 50            |
+# | (256, 256)   | 3.74e-03 | 1.01e-02  | 100, 20            |
+# | (512, 512)   | 1.53e-02 | 2.33e-02  | 20, 10             |
+# | (1024, 1024) | 7.80e-02 | 7.72e-02  | 5, 5               |
+# | (2048, 2048) | 3.21e-01 | 3.52e-01  | 1, 1               |
+
+# For small problems, JAX-based NIFTy is significantly faster than the
+# NumPy-based one; for really small problems it is more than 200 times
+# faster. This is because JAX drastically reduces the Python overhead:
+# most of the heavy lifting happens without going back to Python.
+
+# Notice how, above a certain threshold (here about 2e+5 parameters),
+# NumPy-based and JAX-based NIFTy start to perform similarly well, as the
+# FFT becomes the sole bottleneck.
diff --git a/demos/bernoulli_demo.py b/demos/more/bernoulli_map.py
similarity index 100%
rename from demos/bernoulli_demo.py
rename to demos/more/bernoulli_map.py
diff --git a/demos/misc/convolution.py b/demos/more/convolution.py
similarity index 100%
rename from demos/misc/convolution.py
rename to demos/more/convolution.py
diff --git a/demos/getting_started_density.py b/demos/more/density_estimation.py
similarity index 100%
rename from demos/getting_started_density.py
rename to demos/more/density_estimation.py
diff --git a/demos/getting_started_model_comparison.py b/demos/more/model_comparison.py
similarity index 100%
rename from demos/getting_started_model_comparison.py
rename to demos/more/model_comparison.py
diff --git a/demos/parametric_variational_inference.py b/demos/more/parametric_variational_inference.py
similarity index 100%
rename from demos/parametric_variational_inference.py
rename to demos/more/parametric_variational_inference.py
diff --git a/demos/polynomial_fit.py b/demos/more/polynomial_fit.py
similarity index 100%
rename from demos/polynomial_fit.py
rename to demos/more/polynomial_fit.py
diff --git a/demos/variational_inference_visualized.py b/demos/more/variational_inference_visualized.py
similarity index 100%
rename from demos/variational_inference_visualized.py
rename to demos/more/variational_inference_visualized.py
diff --git a/demos/re/banana.py b/demos/re/banana.py
new file mode 100644
index 0000000000000000000000000000000000000000..c496d29ef001e2a3b420e906b66d4e3089c07eee
--- /dev/null
+++ b/demos/re/banana.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import lax, random
+from jax import jit
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+seed = 42
+key = random.PRNGKey(seed)
+
+
+# %%
+def cartesian_product(arrays, out=None):
+    import numpy as np
+
+    # Generalized N-dimensional products
+    arrays = [np.asarray(x) for x in arrays]
+    la = len(arrays)
+    dtype = np.find_common_type([a.dtype for a in arrays], [])
+    if out is None:
+        out = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
+    for i, a in enumerate(np.ix_(*arrays)):
+        out[..., i] = a
+    return out.reshape(-1, la)
+
+
+def banana_helper_phi_b(b, x):
+    return jnp.array([x[0], x[1] + b * x[0]**2 - 100 * b])
+
+
+def sample_nonstandard_hamiltonian(
+    likelihood, primals, key, cg=jft.static_cg, cg_name=None, cg_kwargs=None
+):
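+    # Draw a sample whose covariance is the *inverse* of the likelihood
+    # metric at `primals`: apply the left-square-root of the metric to
+    # white noise, then invert the metric via conjugate gradient.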
+    if not isinstance(likelihood, jft.Likelihood):
+        te = f"`likelihood` of invalid type; got '{type(likelihood)}'"
+        raise TypeError(te)
+    from jax.tree_util import Partial
+
+    cg_kwargs = cg_kwargs if cg_kwargs is not None else {}
+    cg_kwargs = {"name": cg_name, **cg_kwargs}
+
+    white_sample = jft.random_like(
+        key, likelihood.left_sqrt_metric_tangents_shape
+    )
+    met_smpl = likelihood.left_sqrt_metric(primals, white_sample)
+    inv_metric_at_p = partial(
+        cg, Partial(likelihood.metric, primals), **cg_kwargs
+    )
+    signal_smpl = inv_metric_at_p(met_smpl)[0]
+    return signal_smpl
+
+
+def NonStandardMetricKL(
+    likelihood,
+    primals,
+    n_samples,
+    key,
+    mirror_samples: bool = True,
+    linear_sampling_cg=jft.static_cg,
+    linear_sampling_name=None,
+    linear_sampling_kwargs=None,
+):
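+    # MGVI-style sample bundle for a bare likelihood, i.e. one that is
+    # not wrapped in a StandardHamiltonian: draw `n_samples` metric
+    # samples around `primals`, optionally mirroring them.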
+    from jax.tree_util import Partial
+
+    if not isinstance(likelihood, jft.Likelihood):
+        te = f"`likelihood` of invalid type; got '{type(likelihood)}'"
+        raise TypeError(te)
+
+    draw = Partial(
+        sample_nonstandard_hamiltonian,
+        likelihood=likelihood,
+        primals=primals,
+        cg=linear_sampling_cg,
+        cg_name=linear_sampling_name,
+        cg_kwargs=linear_sampling_kwargs,
+    )
+    subkeys = random.split(key, n_samples)
+    samples_stack = lax.map(lambda k: draw(key=k), subkeys)
+
+    return jft.kl.SampleIter(
+        mean=primals,
+        samples=jft.unstack(samples_stack),
+        linearly_mirror_samples=mirror_samples
+    )
+
+
+# %%
+b = 0.1
+
+signal_response = partial(banana_helper_phi_b, b)
+nll = jft.Gaussian(
+    jnp.zeros(2), lambda x: x / jnp.array([100., 1.])
+) @ signal_response
+
+ham = nll
+ham = ham.jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+
+# %%
+n_mgvi_iterations = 30
+n_samples = [1] * (n_mgvi_iterations - 10) + [2] * 5 + [3, 3, 10, 10, 100]
+n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + 4 * [25]
+absdelta = 1e-12
+
+initial_position = jnp.array([1., 1.])
+mkl_pos = 1e-2 * initial_position
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    samples = NonStandardMetricKL(
+        ham,
+        mkl_pos,
+        n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"miniter": 0},
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=mkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=samples),
+            "hessp": partial(ham_metric, primals_samples=samples),
+            "energy_reduction_factor": None,
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "miniter": 0
+            },
+            "name": "N",
+        }
+    )
+    mkl_pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {samples.at(mkl_pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+b_space_smpls = jnp.array(tuple(samples.at(mkl_pos)))
+
+n_pix_sqrt = 1000
+x = jnp.linspace(-10.0, 10.0, n_pix_sqrt, endpoint=True)
+y = jnp.linspace(2.0, 17.0, n_pix_sqrt, endpoint=True)
+X, Y = jnp.meshgrid(x, y)
+XY = jnp.array([X, Y]).T
+xy = XY.reshape((XY.shape[0] * XY.shape[1], 2))
+es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*b_space_smpls.T)
+ax.plot(*mkl_pos, "rx")
+fig.tight_layout()
+fig.savefig("banana_mgvi_wo_regularization.png", dpi=400)
+plt.close()
diff --git a/demos/re/banana_w_reg.py b/demos/re/banana_w_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93f7e9c203a7396fa267519d04f77e0877cea13
--- /dev/null
+++ b/demos/re/banana_w_reg.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+# %%
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import lax, random
+from jax import jit, value_and_grad
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+seed = 42
+key = random.PRNGKey(seed)
+
+
+# %%
+def cartesian_product(arrays, out=None):
+    import numpy as np
+
+    # Generalized N-dimensional products
+    arrays = [np.asarray(x) for x in arrays]
+    la = len(arrays)
+    dtype = np.find_common_type([a.dtype for a in arrays], [])
+    if out is None:
+        out = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
+    for i, a in enumerate(np.ix_(*arrays)):
+        out[..., i] = a
+    return out.reshape(-1, la)
+
+
+def banana_helper_phi_b(b, x):
+    return jnp.array([x[0], x[1] + b * x[0]**2 - 100 * b])
+
+
+# %%
+b = 0.1
+
+SCALE = 10.
+
+signal_response = lambda s: banana_helper_phi_b(b, SCALE * s)
+nll = jft.Gaussian(
+    jnp.zeros(2), lambda x: x / jnp.array([100., 1.])
+) @ signal_response
+nll = nll.jit()
+nll_vg = jit(value_and_grad(nll))
+
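+# In contrast to the plain banana demo, the likelihood is wrapped in a
+# StandardHamiltonian, i.e. a standard-normal prior on the (rescaled)
+# latent parameters acts as the regularization alluded to in the file
+# name.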
+ham = jft.StandardHamiltonian(nll)
+ham = ham.jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+GeoMetricKL = partial(jft.GeoMetricKL, ham)
+
+# # %%
+# # TODO: Stabilize inversion
+# gkl_position = jnp.array([1.15995025, -0.35110244])
+# special_key = jnp.array([3269562362, 460782344], dtype=jnp.uint32)
+# err = jft.geometrically_sample_standard_hamiltonian(
+#     key=special_key,
+#     hamiltonian=ham,
+#     primals=gkl_position,
+#     mirror_linear_sample=False,
+#     linear_sampling_name="SCG",
+#     linear_sampling_kwargs={"miniter": -1},
+#     non_linear_sampling_name="S",
+#     non_linear_sampling_kwargs={
+#         "cg_kwargs": {
+#             "miniter": -1
+#         },
+#         "maxiter": 20,
+#     }
+# )
+
+# %%  # MGVI
+n_mgvi_iterations = 30
+n_samples = [1] * (n_mgvi_iterations - 2) + [2] + [100]
+n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + 4 * [25]
+absdelta = 1e-10
+
+initial_position = jnp.array([1., 1.])
+mkl_pos = 1e-2 * initial_position
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    mg_samples = MetricKL(
+        mkl_pos,
+        n_samples=n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"miniter": 0}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=mkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=mg_samples),
+            "hessp": partial(ham_metric, primals_samples=mg_samples),
+            "energy_reduction_factor": None,
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "miniter": 0,
+                "name": None
+            },
+            "name": "N"
+        }
+    )
+    mkl_pos = opt_state.x
+    print(
+        (
+            f"Post MGVI Iteration {i}: Energy {mg_samples.at(mkl_pos).mean(ham):2.4e}"
+            f"; #NaNs {jnp.isnan(mkl_pos).sum()}"
+        ),
+        file=sys.stderr
+    )
+
+# %%  # geoVI
+n_geovi_iterations = 15
+n_samples = [1] * (n_geovi_iterations - 2) + [2] + [100]
+n_newton_iterations = [7] * (n_geovi_iterations - 10) + [10] * 6 + [25] * 4
+absdelta = 1e-10
+
+initial_position = jnp.array([1., 1.])
+gkl_pos = 1e-2 * initial_position
+
+for i in range(n_geovi_iterations):
+    print(f"GeoVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    geo_samples = GeoMetricKL(
+        gkl_pos,
+        n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_name=None,
+        linear_sampling_kwargs={"miniter": 0},
+        non_linear_sampling_name=None,
+        non_linear_sampling_kwargs={
+            "cg_kwargs": {
+                "miniter": 0,
+                "absdelta": None
+            },
+            "maxiter": 20,
+        },
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=gkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=geo_samples),
+            "hessp": partial(ham_metric, primals_samples=geo_samples),
+            "energy_reduction_factor": None,
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "miniter": 0,
+                "name": None
+            },
+            "name": "N",
+        }
+    )
+    gkl_pos = opt_state.x
+
+# %%
+absdelta = 1e-10
+opt_state = jft.minimize(
+    None,
+    x0=jnp.array([1., 1.]),
+    method="newton-cg",
+    options={
+        "fun_and_grad": ham_vg,
+        "hessp": ham.metric,
+        "energy_reduction_factor": None,
+        "absdelta": absdelta,
+        "maxiter": 100,
+        "cg_kwargs": {
+            "miniter": 0,
+            "name": None
+        },
+        "name": "MAP"
+    }
+)
+map_pos = opt_state.x
+key, subkey = random.split(key, 2)
+map_geo_samples = GeoMetricKL(
+    map_pos,
+    100,
+    key=subkey,
+    mirror_samples=True,
+    linear_sampling_name=None,
+    linear_sampling_kwargs={"miniter": 0},
+    non_linear_sampling_name=None,
+    non_linear_sampling_kwargs={
+        "cg_kwargs": {
+            "miniter": 0
+        },
+        "maxiter": 20,
+    }
+)
+
+# %%
+
+n_pix_sqrt = 1000
+x = jnp.linspace(-30 / SCALE, 30 / SCALE, n_pix_sqrt)
+y = jnp.linspace(-15 / SCALE, 15 / SCALE, n_pix_sqrt)
+X, Y = jnp.meshgrid(x, y)
+XY = jnp.array([X, Y]).T
+xy = XY.reshape((XY.shape[0] * XY.shape[1], 2))
+es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T
+
+fig, axs = plt.subplots(1, 3, figsize=(16, 9))
+
+b_space_smpls = jnp.array(tuple(mg_samples.at(mkl_pos)))
+contour = axs[0].contour(X, Y, es)
+axs[0].clabel(contour, inline=True, fontsize=10)
+axs[0].scatter(*b_space_smpls.T)
+axs[0].plot(*mkl_pos, "rx")
+axs[0].set_title("MGVI")
+
+b_space_smpls = jnp.array(tuple(geo_samples.at(gkl_pos)))
+contour = axs[1].contour(X, Y, es)
+axs[1].clabel(contour, inline=True, fontsize=10)
+axs[1].scatter(*b_space_smpls.T, alpha=0.7)
+axs[1].plot(*gkl_pos, "rx")
+axs[1].set_title("GeoVI")
+
+b_space_smpls = jnp.array(tuple(map_geo_samples.at(map_pos)))
+contour = axs[2].contour(X, Y, es)
+axs[2].clabel(contour, inline=True, fontsize=10)
+axs[2].scatter(*b_space_smpls.T, alpha=0.7)
+axs[2].plot(*map_pos, "rx")
+axs[2].set_title("MAP + GeoVI Samples")
+
+fig.tight_layout()
+fig.savefig("banana_vi_w_regularization.png", dpi=400)
+plt.close()
diff --git a/demos/re/categorical_L1.py b/demos/re/categorical_L1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbed1c54e9e265e63b4132129a2937264c9d3f7
--- /dev/null
+++ b/demos/re/categorical_L1.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import random
+from jax import jit, value_and_grad
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+
+def build_model(predictors, targets, sh, alpha=1):
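+    # Categorical regression with an L1-type prior: standard-normal
+    # excitations are mapped through a Laplace prior (inducing sparsity
+    # on the weight matrix) and multiplied with the predictors to yield
+    # the logits of a categorical likelihood.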
+    my_laplace_prior = jft.interpolate()(jft.laplace_prior(alpha))
+    matrix = lambda x: my_laplace_prior(x).reshape(sh)
+    model = lambda x: jnp.matmul(predictors, matrix(x))
+    lh = jft.Categorical(targets, axis=1)
+    return {"lh": lh @ model, "logits": model, "matrix": matrix}
+
+
+seed = 42
+key = random.PRNGKey(seed)
+
+N_data = 1024
+N_categories = 10
+N_predictors = 3
+
+n_mgvi_iterations = 5
+n_samples = 5
+mirror_samples = True
+n_newton_iterations = 5
+
+# Create synthetic data
+key, subkey = random.split(key)
+mock_predictors = random.normal(shape=(N_data, N_predictors), key=subkey)
+key, subkey = random.split(key)
+model = build_model(
+    mock_predictors, jnp.zeros((N_data, 1), dtype=jnp.int32),
+    (N_predictors, N_categories)
+)
+latent_truth = random.normal(shape=(N_predictors * N_categories, ), key=subkey)
+key, subkey = random.split(key)
+matrix_truth = model["matrix"](latent_truth)
+logits_truth = model["logits"](latent_truth)
+
+mock_targets = random.categorical(logits=logits_truth, key=subkey)
+key, subkey = random.split(key)
+mock_targets = mock_targets.reshape(N_data, 1)
+
+model = build_model(mock_predictors, mock_targets, (N_predictors, N_categories))
+ham = jft.StandardHamiltonian(likelihood=model["lh"]).jit()
+
+pos_init = .1 * random.normal(shape=(N_predictors * N_categories, ), key=subkey)
+key, subkey = random.split(key)
+pos = pos_init.copy()
+
+diff_to_truth = jnp.linalg.norm(model["matrix"](pos) - matrix_truth)
+print(f"Initial diff to truth {diff_to_truth}", file=sys.stderr)
+
+
+def energy(p, samps):
+    return jnp.mean(jnp.array([ham(p + s) for s in samps]), axis=0)
+
+
+@jit
+def metric(p, t, samps):
+    results = [ham.metric(p + s, t) for s in samps]
+    return jnp.mean(jnp.array(results), axis=0)
+
+
+energy_vag = jit(value_and_grad(energy))
+draw = partial(jft.kl.sample_standard_hamiltonian, hamiltonian=ham)
+
+# Perform the MGVI loop
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    key, *subkeys = random.split(key, 1 + n_samples)
+    samples = [draw(primals=pos, key=k) for k in subkeys]
+
+    Evag = lambda p: energy_vag(p, samples)
+    met = lambda p, t: metric(p, t, samples)
+    opt_state = jft.minimize(
+        None,
+        x0=pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": Evag,
+            "hessp": met,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    diff_to_truth = jnp.linalg.norm(model["matrix"](pos) - matrix_truth)
+    print(
+        (
+            f"Post MGVI Iteration {i}: Energy {Evag(pos)[0]:2.4e}"
+            f"; diff to truth {diff_to_truth}"
+        ),
+        file=sys.stderr
+    )
+
+posterior_samps = [s + pos for s in samples]
+
+matrix_samps = jnp.array([model["matrix"](s) for s in posterior_samps])
+matrix_mean = jnp.mean(matrix_samps, axis=0)
+matrix_std = jnp.std(matrix_samps, axis=0)
+xx = jnp.linspace(-3.5, 3.5, 2)
+plt.plot(xx, xx)
+plt.errorbar(
+    matrix_truth.reshape(-1),
+    matrix_mean.reshape(-1),
+    yerr=matrix_std.reshape(-1),
+    fmt='o',
+    color="black"
+)
+plt.xlabel("truth")
+plt.ylabel("inferred value")
+plt.savefig("matrix_fit.png", dpi=400)
+plt.close()
diff --git a/demos/re/correlated_field_w_known_spectrum.py b/demos/re/correlated_field_w_known_spectrum.py
new file mode 100644
index 0000000000000000000000000000000000000000..952572305fa2e642bf5dcc16cd1651312d313832
--- /dev/null
+++ b/demos/re/correlated_field_w_known_spectrum.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import random
+from jax import jit
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+
+@jit
+def cosine_similarity(x, y):
+    return jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y))
+
+
+def hartley(p, axes=None):
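+    # Hartley transform: the real plus the imaginary part of the Fourier
+    # transform. It maps real fields to real fields and is its own
+    # inverse up to a factor of the total number of pixels.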
+    from jax.numpy import fft
+
+    # `fftn`'s second positional argument is the shape `s`, not `axes`
+    tmp = fft.fftn(p, axes=axes)
+    return tmp.real + tmp.imag
+
+
+seed = 42
+key = random.PRNGKey(seed)
+
+dims = (1024, )
+
+n_mgvi_iterations = 3
+n_samples = 4
+n_newton_iterations = 5
+absdelta = 1e-4 * jnp.prod(jnp.array(dims))
+
+cf = {"loglogavgslope": 2.}
+loglogslope = cf["loglogavgslope"]
+power_spectrum = lambda k: 1. / (k**loglogslope + 1.)
+
+modes = jnp.arange((dims[0] / 2) + 1., dtype=float)
+harmonic_power = power_spectrum(modes)
+# Every mode appears exactly two times, first ascending then descending
+# Save a little on the computational side by mirroring the ascending part
+harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1]))
+
+# Specify the model
+correlated_field = lambda x: hartley(harmonic_power * x.val)
+signal_response = lambda x: jnp.exp(1. + correlated_field(x))
+noise_cov = lambda x: 0.1**2 * x
+noise_cov_inv = lambda x: 0.1**-2 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.Field(random.normal(shape=dims, key=subkey))
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = jnp.sqrt(noise_cov(jnp.ones(dims))
+                      ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = jft.Gaussian(data, noise_cov_inv) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+
+key, subkey = random.split(key)
+pos_init = random.normal(shape=dims, key=subkey)
+pos = 1e-2 * jft.Field(pos_init)
+
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=pos,
+        method="trust-ncg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=samples),
+            "hessp": partial(ham_metric, primals_samples=samples),
+            "initial_trust_radius": 1e+1,
+            "max_trust_radius": 1e+4,
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations,
+            "name": "N",
+            "subproblem_kwargs": {
+                "miniter": 6,
+            }
+        }
+    )
+    # opt_state = jft.minimize(
+    #     None,
+    #     x0=pos,
+    #     method="newton-cg",
+    #     options={
+    #         "fun_and_grad": partial(ham_vg, primals_samples=samples),
+    #         "hessp": partial(ham_metric, primals_samples=samples),
+    #         "absdelta": absdelta,
+    #         "maxiter": n_newton_iterations
+    #     }
+    # )
+    pos = opt_state.x
+    print(
+        (
+            f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}"
+            f"; Cos-Sim {cosine_similarity(pos.val, pos_truth.val):2.3%}"
+            f"; #NaNs {jnp.isnan(pos.val).sum()}"
+        ),
+        file=sys.stderr
+    )
+
+post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos)))
+fig, ax = plt.subplots()
+ax.plot(signal_response_truth, alpha=0.7, label="Signal")
+ax.plot(noise_truth, alpha=0.7, label="Noise")
+ax.plot(data, alpha=0.7, label="Data")
+ax.plot(post_sr_mean, alpha=0.7, label="Reconstruction")
+ax.legend()
+fig.tight_layout()
+fig.savefig("cf_w_known_spectrum.png", dpi=400)
+plt.close()
diff --git a/demos/re/correlated_field_w_unknown_factorizing_spectra.py b/demos/re/correlated_field_w_unknown_factorizing_spectra.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e246479bf954aa5703ad65073ed857585dff87
--- /dev/null
+++ b/demos/re/correlated_field_w_unknown_factorizing_spectra.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import random
+from jax import jit
+from jax.config import config as jax_config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+jax_config.update("jax_enable_x64", True)
+
+seed = 42
+key = random.PRNGKey(seed)
+
+dims_ax1 = (128, )
+dims_ax2 = (256, )
+
+n_mgvi_iterations = 3
+n_samples = 4
+n_newton_iterations = 10
+absdelta = 1e-4 * jnp.prod(jnp.array(dims_ax1 + dims_ax2))
+
+cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
+cf_fl = {
+    "fluctuations": (1e-1, 5e-3),
+    "loglogavgslope": (-1., 1e-2),
+    "flexibility": (1e+0, 5e-1),
+    "asperity": (5e-1, 1e-1),
+    "harmonic_domain_type": "Fourier"
+}
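+# One `add_fluctuations` call per axis models a separate correlation
+# structure along each axis; the overall correlation is their outer
+# product, i.e. the power spectrum factorizes over the axes (hence the
+# file name).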
+cfm = jft.CorrelatedFieldMaker("cf")
+cfm.set_amplitude_total_offset(**cf_zm)
+d = 1. / dims_ax1[0]
+cfm.add_fluctuations(dims_ax1, distances=d, **cf_fl, prefix="ax1")
+d = 1. / dims_ax2[0]
+cfm.add_fluctuations(dims_ax2, distances=d, **cf_fl, prefix="ax2")
+correlated_field, ptree = cfm.finalize()
+
+signal_response = lambda x: correlated_field(x)
+noise_cov = lambda x: 0.1**2 * x
+noise_cov_inv = lambda x: 0.1**-2 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.random_like(subkey, ptree)
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = jnp.sqrt(
+    noise_cov(jnp.ones(signal_response_truth.shape))
+) * random.normal(shape=signal_response_truth.shape, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = jft.Gaussian(data, noise_cov_inv) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos_init = jft.random_like(subkey, ptree)
+pos = 1e-2 * jft.Field(pos_init)
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=samples),
+            "hessp": partial(ham_metric, primals_samples=samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+namps = cfm.get_normalized_amplitudes()
+post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos)))
+post_namps1_mean = jft.mean(tuple(namps[0](s)[1:] for s in samples.at(pos)))
+post_namps2_mean = jft.mean(tuple(namps[1](s)[1:] for s in samples.at(pos)))
+to_plot = [
+    ("Signal", signal_response_truth, "im"),
+    ("Noise", noise_truth, "im"),
+    ("Data", data, "im"),
+    ("Reconstruction", post_sr_mean, "im"),
+    ("Ax1", (namps[0](pos_truth)[1:], post_namps1_mean), "loglog"),
+    ("Ax2", (namps[1](pos_truth)[1:], post_namps2_mean), "loglog"),
+]
+fig, axs = plt.subplots(2, 3, figsize=(16, 9))
+for ax, (title, field, tp) in zip(axs.flat, to_plot):
+    ax.set_title(title)
+    if tp == "im":
+        im = ax.imshow(field, cmap="inferno")
+        plt.colorbar(im, ax=ax, orientation="horizontal")
+    else:
+        ax_plot = ax.loglog if tp == "loglog" else ax.plot
+        field = field if isinstance(field, (tuple, list)) else (field, )
+        for f in field:
+            ax_plot(f, alpha=0.7)
+fig.tight_layout()
+fig.savefig("cf_w_unknown_factorizing_spectra.png", dpi=400)
+plt.close()
diff --git a/demos/re/correlated_field_w_unknown_spectrum.py b/demos/re/correlated_field_w_unknown_spectrum.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5254b753e918b60a4f4fe71ef292d71eba6693c
--- /dev/null
+++ b/demos/re/correlated_field_w_unknown_spectrum.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import random
+from jax import jit
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+seed = 42
+key = random.PRNGKey(seed)
+
+dims = (256, 256)
+
+n_mgvi_iterations = 3
+n_samples = 4
+n_newton_iterations = 10
+absdelta = 1e-4 * jnp.prod(jnp.array(dims))
+
+cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
+cf_fl = {
+    "fluctuations": (1e-1, 5e-3),
+    "loglogavgslope": (-1., 1e-2),
+    "flexibility": (1e+0, 5e-1),
+    "asperity": (5e-1, 5e-2),
+    "harmonic_domain_type": "Fourier"
+}
+cfm = jft.CorrelatedFieldMaker("cf")
+cfm.set_amplitude_total_offset(**cf_zm)
+cfm.add_fluctuations(dims, distances=1. / dims[0], **cf_fl, prefix="ax1")
+correlated_field, ptree = cfm.finalize()
+
+signal_response = lambda x: jnp.exp(correlated_field(x))
+noise_cov = lambda x: 0.1**2 * x
+noise_cov_inv = lambda x: 0.1**-2 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.random_like(subkey, ptree)
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = jnp.sqrt(noise_cov(jnp.ones(dims))
+                      ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = jft.Gaussian(data, noise_cov_inv) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos_init = jft.random_like(subkey, ptree)
+pos = 1e-2 * jft.Field(pos_init.copy())
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=samples),
+            "hessp": partial(ham_metric, primals_samples=samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+namps = cfm.get_normalized_amplitudes()
+post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos)))
+post_a_mean = jft.mean(tuple(cfm.amplitude(s)[1:] for s in samples.at(pos)))
+to_plot = [
+    ("Signal", signal_response_truth, "im"),
+    ("Noise", noise_truth, "im"),
+    ("Data", data, "im"),
+    ("Reconstruction", post_sr_mean, "im"),
+    ("Ax1", (cfm.amplitude(pos_truth)[1:], post_a_mean), "loglog"),
+]
+fig, axs = plt.subplots(2, 3, figsize=(16, 9))
+for ax, (title, field, tp) in zip(axs.flat, to_plot):
+    ax.set_title(title)
+    if tp == "im":
+        im = ax.imshow(field, cmap="inferno")
+        plt.colorbar(im, ax=ax, orientation="horizontal")
+    else:
+        ax_plot = ax.loglog if tp == "loglog" else ax.plot
+        field = field if isinstance(field, (tuple, list)) else (field, )
+        for f in field:
+            ax_plot(f, alpha=0.7)
+fig.tight_layout()
+fig.savefig("cf_w_unknown_spectrum.png", dpi=400)
+plt.close()
diff --git a/demos/re/graph_refine.py b/demos/re/graph_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..edc5e701cc134ee47a6058421d51ce7ca0b7f606
--- /dev/null
+++ b/demos/re/graph_refine.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import jax
+from jax import numpy as jnp
+from jax import random
+import matplotlib.pyplot as plt
+import numpy as np
+
+import nifty8.re as jft
+
+
+def get_kernel(layer_weights, depth, n_samples=100):
+    xi = {"offset": 0., "layer_weights": layer_weights}
+    kernel = np.zeros(2**depth)
+    for _ in range(n_samples):
+        xi["excitations"] = jnp.array(rng.normal(size=(2**depth, )))
+        r = fwd(xi)
+        for i in range(r.size):
+            kernel[i] += np.mean(r * np.roll(r, i))
+    kernel /= n_samples  # `n_samples` is an int; `len()` on it would raise
+    return kernel
+
+
+def fwd(xi):
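+    # Multi-resolution generative model: repeatedly smooth the
+    # excitations with a small kernel and coarsen by a factor of two,
+    # then sum all layers, each weighted and upsampled back to the full
+    # resolution.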
+    offset = xi["offset"]
+    excitations = xi["excitations"]
+    layer_wgt = xi["layer_weights"]
+
+    kernel = jnp.array([1., 2., 1.])
+    kernel /= kernel.sum()
+    layers = [excitations]
+    while layers[-1].size > 1:
+        lvl = layers[-1]
+        if layers[-1].size > 2:
+            lvl = jnp.convolve(lvl, kernel, mode="same")
+        layers += [0.5 * lvl.reshape(-1, 2).sum(axis=1)]
+    if len(layers) != len(layer_wgt):
+        raise ValueError("expected one weight per layer, i.e. depth + 1 entries")
+
+    field = offset
+    for d, (wgt, lvl) in enumerate(zip(layer_wgt, layers)):
+        field += wgt * jnp.repeat(lvl, 2**d)
+
+    return field
+
+
+# %%
+rng = np.random.default_rng(42)
+depth = 8
+
+for _ in range(10):
+    layer_weights = jnp.array(rng.normal(size=(depth + 1, )))
+    layer_weights = jnp.exp(0.1 * layer_weights)  # log-normal weights
+    kernel = get_kernel(layer_weights, depth, n_samples=30)
+    plt.plot(kernel)
+plt.show()
+
+# %%
+spec = np.abs(np.fft.fft(kernel))  # magnitude; the raw FFT output is complex
+plt.plot(spec)
+plt.yscale("log")
+plt.xscale("log")
+plt.show()
+
+# %%
+
+rng = np.random.default_rng(42)
+depth = 12
+xi = {
+    "offset": 0.,
+    "excitations": jnp.array(rng.normal(size=(2**depth, ))),
+    # "layer_weights": jnp.exp(+jnp.array(rng.normal(size=(depth + 1, )))),
+    "layer_weights": jnp.exp(0.1 * jnp.arange(depth + 1, dtype=float)),
+    # "layer_weights": jnp.array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,]),
+}
+
+plt.plot(xi["excitations"], label="excitations", alpha=0.6)
+plt.plot(fwd(xi), label="Forward Model", alpha=0.6)
+plt.legend()
+plt.show()
+
+# %%
+cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
+cf_fl = {
+    "fluctuations": (1e-1, 5e-3),
+    "loglogavgslope": (-1.5, 1e-2),
+    "flexibility": (5e-1, 1e-1),
+    "asperity": (5e-1, 5e-2),
+    "harmonic_domain_type": "Fourier"
+}
+dims = jnp.array([2**depth])
+
+cfm = jft.CorrelatedFieldMaker("cf")
+cfm.set_amplitude_total_offset(**cf_zm)
+cfm.add_fluctuations(dims, distances=1. / dims[0], **cf_fl, prefix="ax1")
+correlated_field, ptree = cfm.finalize()
+key = random.PRNGKey(42)
+
+pos_truth = jft.random_like(key, ptree)
+plt.plot(correlated_field(pos_truth))
+plt.show()
+
+# %%
+d = correlated_field(pos_truth)
+lh = jax.jit(
+    lambda x: ((d - fwd(x))**2).sum() +
+    sum([(el**2).sum() for el in jax.tree_util.tree_leaves(x)])
+)
+print(lh(xi))
+
+# %%
+opt_state = jft.minimize(
+    lh,
+    jft.Field(xi),
+    method="newton-cg",
+    options={
+        "name": "N",
+        "absdelta": 0.1,
+        "maxiter": 30
+    }
+)
+
+# %%
+plt.plot(correlated_field(pos_truth), label="truth")
+plt.plot(fwd(opt_state.x), label="reconstruction")
+plt.legend()
+plt.show()
+
+# %%
+pos_rec = opt_state.x.val.copy()
+pos_rec["layer_weights"] = pos_rec["layer_weights"].at[:-8].set(0.)
+
+pos_truth = jft.random_like(key, ptree)
+plt.plot(correlated_field(pos_truth), alpha=0.7, label="truth")
+plt.plot(fwd(opt_state.x), alpha=0.7, label="reconstruction")
+plt.plot(fwd(pos_rec), alpha=0.7, label="reconstruction coarse")
+plt.legend()
+plt.show()
diff --git a/demos/re/hmc_multimodality.py b/demos/re/hmc_multimodality.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee563fb545b40e912d5673a13787c2e2c61771d
--- /dev/null
+++ b/demos/re/hmc_multimodality.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+# %%
+from functools import partial
+
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+
+def loggaussian(x, mu, sigma):
+    # `sigma` denotes the standard deviation, hence the square
+    return -0.5 * (x - mu)**2 / sigma**2
+
+
+def sum_of_gaussians(x, separation, sigma1, sigma2):
+    return -jnp.logaddexp(
+        loggaussian(x, 0, sigma1), loggaussian(x, separation, sigma2)
+    )
+
+
+ham = partial(sum_of_gaussians, separation=10., sigma1=1., sigma2=1.)
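+
+# Potential energy with two well-separated modes. The two NUTS chains
+# below differ only in their inverse mass matrix, illustrating how this
+# choice affects mixing between the modes.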
+
+N = 100000
+SEED = 43
+EPS = 0.3
+
+subplots = (2, 2)
+fig_width_pt = 426  # pt (a4paper, and such)
+# fig_width_pt = 360 # pt
+inches_per_pt = 1 / 72.27
+fig_width_in = 0.9 * fig_width_pt * inches_per_pt
+fig_height_in = fig_width_in * 0.618 * (subplots[0] / subplots[1])
+fig_dims = (fig_width_in, fig_height_in * 1.5)
+
+fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
+    subplots[0],
+    subplots[1],
+    sharex='col',
+    figsize=fig_dims,
+    gridspec_kw={'width_ratios': [1, 2]}
+)
+
+# %%
+nuts_sampler = jft.NUTSChain(
+    potential_energy=ham,
+    inverse_mass_matrix=5.,
+    position_proto=jnp.array(0.),
+    step_size=EPS,
+    max_tree_depth=15,
+    max_energy_difference=1000.,
+)
+
+chain, _ = nuts_sampler.generate_n_samples(
+    SEED, jnp.array(3.), num_samples=N, save_intermediates=True
+)
+print(f"small mass matrix acceptance: {chain.acceptance}")
+
+ax1.hist(chain.samples, bins=30, density=True)
+ax2.plot(chain.samples, linewidth=0.5)
+
+ax1.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$')
+ax2.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$')
+
+# %%
+nuts_sampler = jft.NUTSChain(
+    potential_energy=ham,
+    inverse_mass_matrix=50.,
+    position_proto=jnp.array(0.),
+    step_size=EPS,
+    max_tree_depth=15,
+    max_energy_difference=1000.,
+)
+
+chain, _ = nuts_sampler.generate_n_samples(
+    SEED, jnp.array(3.), num_samples=N, save_intermediates=True
+)
+print(f"large mass matrix acceptance: {chain.acceptance}")
+
+ax3.hist(chain.samples, bins=30, density=True)
+ax4.plot(chain.samples, linewidth=0.5)
+
+ax3.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$')
+ax4.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$')
+
+# %%
+xs = jnp.linspace(-10, 20, num=500)
+Z = jnp.trapz(jnp.exp(-ham(xs)), xs)
+ax1.plot(xs, jnp.exp(-ham(xs)) / Z, linewidth=0.5, c='r')
+ax3.plot(xs, jnp.exp(-ham(xs)) / Z, linewidth=0.5, c='r')
+
+ax1.set_ylabel('frequency')
+ax2.set_ylabel('position')
+ax3.set_xlabel('position')
+ax3.set_ylabel('frequency')
+ax4.set_xlabel('time')
+ax4.set_ylabel('position')
+
+#fig.suptitle("sum of two Gaussians, with different choices of mass matrix")
+
+fig.tight_layout()
+fig.savefig("multimodal.pdf", bbox_inches='tight')
+print("final figure saved as multimodal.pdf")
diff --git a/demos/re/hmc_nuts_trajectories.py b/demos/re/hmc_nuts_trajectories.py
new file mode 100644
index 0000000000000000000000000000000000000000..340fc8ca2eb40caab298c67f8fe716aae48c114a
--- /dev/null
+++ b/demos/re/hmc_nuts_trajectories.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+# %%
+#
+# WARNING: This code does not behave deterministically. It works fine when
+# executed cell by cell using VS Code's notebook functionality, but when
+# run from the command line with either python3 or ipython3, it simply
+# stops adding points to the debug list after some random number of
+# leapfrog steps. This is probably due to an issue with host_callback.
+#
+
+import jax.numpy as jnp
+import matplotlib
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+# %%
+jft.hmc._DEBUG_FLAG = True
+
+# %%
+cov = jnp.array([10., 1.])
+
+potential_energy = lambda q: jnp.sum(0.5 * q**2 / cov)
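+
+# Anisotropic Gaussian potential; the ellipse in the plot below indicates
+# its covariance structure, with the NUTS trajectories overlaid.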
+
+initial_position = jnp.array([1., 1.])
+
+sampler = jft.NUTSChain(
+    potential_energy=potential_energy,
+    inverse_mass_matrix=1.,
+    position_proto=initial_position,
+    step_size=0.12,
+    max_tree_depth=10,
+)
+
+# %%
+jft.hmc._DEBUG_STORE = []
+jft.hmc._DEBUG_TREE_END_IDXS = []
+jft.hmc._DEBUG_SUBTREE_END_IDXS = []
+
+chain, _ = sampler.generate_n_samples(
+    48, initial_position, num_samples=5, save_intermediates=True
+)
+
+plt.hist(chain.depths)
+plt.show()
+
+# %%
+debug_pos = jnp.array([qp.position for qp in jft.hmc._DEBUG_STORE])
+print(len(debug_pos))
+
+# %%
+prop_cycle = plt.rcParams['axes.prop_cycle']
+colors = prop_cycle.by_key()['color']
+
+ax = plt.gca()
+ellipse = matplotlib.patches.Ellipse(
+    xy=(0, 0),
+    width=jnp.sqrt(cov[0]),
+    height=jnp.sqrt(cov[1]),
+    edgecolor='k',
+    fc='None',
+    lw=1
+)
+ax.add_patch(ellipse)
+
+color_idx = 0
+start_and_end_idxs = zip(
+    [
+        0,
+    ] + jft.hmc._DEBUG_SUBTREE_END_IDXS[:-1], jft.hmc._DEBUG_SUBTREE_END_IDXS
+)
+for start_idx, end_idx in start_and_end_idxs:
+    segment = debug_pos[start_idx:end_idx]  # avoid shadowing the builtin `slice`
+    ax.plot(
+        segment[:, 0],
+        segment[:, 1],
+        '-o',
+        markersize=1,
+        linewidth=0.5,
+        color=colors[color_idx % len(colors)]
+    )
+    if end_idx in jft.hmc._DEBUG_TREE_END_IDXS:
+        color_idx = (color_idx + 1) % len(colors)
+
+ax.scatter(
+    chain.samples[:, 0],
+    chain.samples[:, 1],
+    marker='x',
+    color='k',
+    label='samples'
+)
+ax.scatter(initial_position[0], initial_position[1], label='starting position')
+ax.set_xlabel('x')
+ax.set_ylabel('y')
+ax.legend()
+
+fig_width_pt = 426  # pt (a4paper, and such)
+# fig_width_pt = 360 # pt
+inches_per_pt = 1 / 72.27
+fig_width_in = 0.9 * fig_width_pt * inches_per_pt
+fig_height_in = fig_width_in * 0.618
+fig_dims = (fig_width_in, fig_height_in)
+
+plt.tight_layout()
+# save before plt.show(); otherwise an empty figure may be written
+plt.savefig("trajectories.pdf", bbox_inches='tight')
+plt.show()
diff --git a/demos/re/hmc_wiener_filter.py b/demos/re/hmc_wiener_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8664e9937d44de8aac90e672176703768cb6a02c
--- /dev/null
+++ b/demos/re/hmc_wiener_filter.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+#%%
+from jax import numpy as jnp
+from jax import lax, random
+import jax
+from jax.config import config
+import matplotlib
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+matplotlib.rcParams['figure.figsize'] = (10, 7)
+
+#%%
+dims = (512, )
+#datadims = (4,)
+loglogslope = 2.
+power_spectrum = lambda k: 1. / (k**loglogslope + 1.)
+modes = jnp.arange((dims[0] / 2) + 1., dtype=float)
+harmonic_power = power_spectrum(modes)
+harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1]))
+
+#%%
+correlated_field = lambda x: jft.correlated_field.hartley(
+    # x is a signal in Fourier space;
+    # each mode's amplitude gets multiplied by its harmonic power
+    # and the whole signal is transformed back
+    harmonic_power * x
+)
+
+# %%
+# alternative, nonlinear signal response:
+# signal_response = lambda x: jnp.exp(1. + correlated_field(x))
signal_response = lambda x: correlated_field(x)
+# alternative, constant signal response for which the data would be
+# $\vec{d} = (1, \ldots, 1)^\top s + \vec{n}$ with $s \in \mathbb{R}$ and
+# $\vec{n} \sim \mathcal{G}(0, N)$ (requires uncommenting `datadims` above):
+# signal_response = lambda x: jnp.ones(shape=datadims) * x
+
+# unit noise covariance, i.e. its inverse square root is the identity
+noise_cov_inv_sqrt = lambda x: 1.**-1 * x
+
+#%%
+# create synthetic data
+seed = 43
+key = random.PRNGKey(seed)
+key, subkey = random.split(key)
+# normal random fourier amplitude
+pos_truth = random.normal(shape=dims, key=subkey)
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+# 1. / noise_cov_inv_sqrt(jnp.ones(dims)) is the standard deviation of the
+# noise Gaussian
+noise_truth = 1. / noise_cov_inv_sqrt(jnp.ones(dims)
+                                     ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+#%%
+plt.plot(signal_response_truth, label='signal response')
+#plt.plot(noise_truth, label='noise', linewidth=0.5)
+plt.plot(data, 'k.', label='noisy data', markersize=4.)
+plt.xlabel('real space domain')
+plt.ylabel('field value')
+plt.legend()
+plt.title("signal and data")
+plt.show()
+
+
+#%%
+def Gaussian(data, noise_cov_inv_sqrt):
+    # Simple but not very generic Gaussian energy
+    def hamiltonian(primals):
+        p_res = primals - data
+        # weight the residual with the inverse noise standard deviation
+        l_res = noise_cov_inv_sqrt(p_res)
+        return 0.5 * jnp.sum(l_res**2)
+
+    return jft.Likelihood(hamiltonian)
+
+
+# negative log likelihood
+nll = Gaussian(data, noise_cov_inv_sqrt) @ signal_response
+
+#%%
+ham = jft.StandardHamiltonian(likelihood=nll)
+ham_gradient = jax.grad(ham)
+
+
+# %%
+def plot_mean_and_stddev(ax, samples, mean_of_r=None, truth=False, **kwargs):
+    signal_response_of_samples = lax.map(signal_response, samples)
+    if mean_of_r is None:
+        mean_of_signal_response = jnp.mean(signal_response_of_samples, axis=0)
+    else:
+        mean_of_signal_response = mean_of_r
+    mean_label = kwargs.pop('mean_label', 'sample mean of signal response')
+    ax.plot(mean_of_signal_response, label=mean_label)
+    std_dev_of_signal_response = jnp.std(signal_response_of_samples, axis=0)
+    if truth:
+        ax.plot(signal_response_truth, label="truth")
+    ax.fill_between(
+        jnp.arange(len(mean_of_signal_response)),
+        y1=mean_of_signal_response - std_dev_of_signal_response,
+        y2=mean_of_signal_response + std_dev_of_signal_response,
+        color='grey',
+        alpha=0.5
+    )
+    title = kwargs.pop('title', 'position samples')
+    if title is not None:
+        ax.set_title(title)
+    xlabel = kwargs.pop('xlabel', 'position')
+    if xlabel is not None:
+        ax.set_xlabel(xlabel)
+    ylabel = kwargs.pop('ylabel', 'signal response')
+    if ylabel is not None:
+        ax.set_ylabel(ylabel)
+    ax.legend(loc='lower right', fontsize=8)
+
+
+#%%
+key, subkey = random.split(key)
+initial_position = random.uniform(key=subkey, shape=pos_truth.shape)
+
+sampler = jft.HMCChain(
+    potential_energy=ham,
+    inverse_mass_matrix=1.,
+    position_proto=initial_position,
+    step_size=0.05,
+    num_steps=128,
+)
+
+chain, _ = sampler.generate_n_samples(
+    42, initial_position, num_samples=30, save_intermediates=True
+)
+print(f"acceptance ratio: {chain.acceptance}")
+
+# %%
+plot_mean_and_stddev(plt.gca(), chain.samples, truth=True)
+plt.title("HMC position samples")
+plt.show()
+
+# %% [markdown]
+# # NUTS
+
+# %%
+jft.hmc._DEBUG_STORE = []
+
+sampler = jft.NUTSChain(
+    position_proto=initial_position,
+    potential_energy=ham,
+    inverse_mass_matrix=1.,
+    # step_size=0.9193 integrates to tree depth ~3-7 (very smooth sample mean);
+    # step_size=0.8193 integrates to tree depth ~22 (very noisy sample mean)
+    step_size=0.05,
+    max_tree_depth=17,
+)
+
+chain, _ = sampler.generate_n_samples(
+    42, initial_position, num_samples=30, save_intermediates=True
+)
+plt.hist(chain.depths, bins=jnp.arange(sampler.max_tree_depth + 2))
+plt.title('NUTS tree depth histogram')
+plt.xlabel('tree depth')
+plt.ylabel('count')
+plt.show()
+
+# %%
+plot_mean_and_stddev(plt.gca(), chain.samples, truth=True)
+plt.title("NUTS position samples")
+plt.show()
+
+# %%
+if jft.hmc._DEBUG_FLAG:
+    debug_pos = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, :]
+
+    for idx, dbgp in enumerate(debug_pos):
+        plt.plot(signal_response(dbgp), label=f'{idx}', alpha=0.1)
+    #plt.legend()
+
+    # %%
+    debug_pos_x = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, 0]
+    debug_pos_y = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, 1]
+    plt.scatter(debug_pos_x, debug_pos_y, s=0.1, color='k')
+    plt.show()
+
+# %% [markdown]
+# # 1D position and momentum time series
+
+# %%
+if chain.samples[0].shape == (1, ):
+    plt.plot(chain.samples, label='position')
+    #plt.plot(momentum_samples, label='momentum', linewidth=0.2)
+    #plt.plot(unintegrated_momenta, label='unintegrated momentum', linewidth=0.2)
+    plt.title('position and momentum time series')
+    plt.xlabel('time')
+    plt.ylabel('position, momentum')
+    plt.legend()
+    plt.show()
+
+# %% [markdown]
+# # Energy time series
+
+# %%
+potential_energies = lax.map(ham, chain.samples)
+kinetic_energies = jnp.sum(chain.trees.proposal_candidate.momentum**2, axis=1)
+#rejected_potential_energies = lax.map(ham, rejected_position_samples)
+#rejected_kinetic_energies = jnp.sum(rejected_momentum_samples**2, axis=1)
+plt.plot(potential_energies, label='pot')
+plt.plot(kinetic_energies, label='kin', linewidth=1)
+plt.plot(kinetic_energies + potential_energies, label='total', linewidth=1)
+#plt.plot(rejected_potential_energies , label='rejected_pot')
+#plt.plot(rejected_kinetic_energies , label='rejected_kin', linewidth=2)
+#plt.plot(rejected_kinetic_energies + rejected_potential_energies, label='rejected_total', linewidth=0.2)
+plt.title('NUTS energy time series')
+plt.xlabel('time')
+plt.ylabel('energy')
+plt.yscale('log')
+plt.legend()
+plt.show()
+
+# %% [markdown]
+# # Wiener Filter
+
+# Ingredients:
+# * jax.linear_transpose yields R^\dagger
+# * squaring noise_cov_inv_sqrt yields N^{-1}
+# * S is the identity due to the StandardHamiltonian
+# * jax.scipy.sparse.linalg.cg applies D
+
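+# In formulas, the Wiener filter posterior mean is
+#   m = D j ,  D = (S^{-1} + R^\dagger N^{-1} R)^{-1} ,  j = R^\dagger N^{-1} d ,
+# which is what the code below assembles step by step.
+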
+# pos_truth is only needed to infer the shape and dtype of the signal space
+_impl_signal_response_dagger = jax.linear_transpose(signal_response, pos_truth)
+# the transposed function returns a (1,)-tuple which we unpack right here
+signal_response_dagger = lambda d: _impl_signal_response_dagger(d)[0]
+# noise_cov_inv_sqrt is diagonal, hence N^{-1} is simply its square
+noise_cov_inv = lambda d: noise_cov_inv_sqrt(noise_cov_inv_sqrt(d))
+
+# the signal prior covariance S is the identity (see StandardHamiltonian),
+# hence D^{-1} = 1 + R^\dagger N^{-1} R
+D_inv = lambda s: s + signal_response_dagger(noise_cov_inv(signal_response(s)))
+
+j = signal_response_dagger(noise_cov_inv(data))
+
+m, _ = jax.scipy.sparse.linalg.cg(D_inv, j)
+
+# %%
+
+plt.plot(signal_response(m), label='signal response of mean')
+plt.plot(signal_response_truth, label='true signal response')
+plt.legend()
+plt.title('Wiener Filter')
+plt.show()
+
+
+# %%
+def sample_from_d_inv(key):
+    s_inv_key, rnr_key = random.split(key)
+    # prior part: S is the identity, so a standard normal sample suffices
+    s_inv_smpl = random.normal(s_inv_key, pos_truth.shape)
+    # likelihood part: sample standard-normally in data space, then apply
+    # R^\dagger \sqrt{N^{-1}}
+    rnr_smpl = signal_response_dagger(
+        noise_cov_inv_sqrt(random.normal(rnr_key, signal_response_truth.shape))
+    )
+    return s_inv_smpl + rnr_smpl
+
+
+def sample_from_d(key):
+    d_inv_smpl = sample_from_d_inv(key)
+    # apply D = (D^{-1})^{-1} via CG to obtain a sample with covariance D
+    smpl, _ = jft.cg(D_inv, d_inv_smpl, maxiter=32)
+    return smpl
+
+
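+# posterior samples are the Wiener-filter mean plus zero-centered residual
+# samples with covariance D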
+wiener_samples = jnp.array(
+    list(map(lambda key: sample_from_d(key) + m, random.split(key, 30)))
+)
+
+# %%
+subplots = (3, 1)
+fig_height_pt = 541  # pt
+#fig_width_pt = 360 # pt
+inches_per_pt = 1 / 72.27
+fig_height_in = 1. * fig_height_pt * inches_per_pt
+fig_width_in = fig_height_in / 0.618 * (subplots[1] / subplots[0])
+fig_dims = (fig_width_in, fig_height_in)
+
+fig, (ax_raw, ax_nuts, ax_wiener) = plt.subplots(
+    subplots[0], subplots[1], sharex=True, sharey=False, figsize=fig_dims
+)
+
+ax_raw.plot(signal_response_truth, label='true signal response')
+ax_raw.plot(data, 'k.', label='noisy data', markersize=2.)
+#ax_raw.set_xlabel('position')
+ax_raw.set_ylabel('signal response')
+ax_raw.set_title("signal and data")
+ax_raw.legend(fontsize=8)
+
+plot_mean_and_stddev(
+    ax_nuts,
+    chain.samples,
+    truth=True,
+    title="NUTS",
+    xlabel=None,
+    mean_label='sample mean'
+)
+plot_mean_and_stddev(
+    ax_wiener,
+    wiener_samples,
+    mean_of_r=signal_response(m),
+    truth=True,
+    title="Wiener Filter",
+    mean_label='exact posterior mean'
+)
+
+fig.tight_layout()
+
+plt.savefig('wiener.pdf', bbox_inches='tight')
+print("final plot saved as wiener.pdf")
diff --git a/demos/re/lognorm_w_hmc.py b/demos/re/lognorm_w_hmc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d519b98de4f27e8bbfc8109d3c03cd9c6a62c328
--- /dev/null
+++ b/demos/re/lognorm_w_hmc.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+# %%
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import lax, random
+from jax import jit
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+seed = 42
+key = random.PRNGKey(seed)
+
+
+# %%
+def cartesian_product(arrays, out=None):
+    import numpy as np
+
+    # Generalized N-dimensional products
+    arrays = [np.asarray(x) for x in arrays]
+    la = len(arrays)
+    dtype = np.find_common_type([a.dtype for a in arrays], [])
+    if out is None:
+        out = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
+    for i, a in enumerate(np.ix_(*arrays)):
+        out[..., i] = a
+    return out.reshape(-1, la)
+
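+# e.g. cartesian_product((np.arange(2), np.arange(3))) yields a (6, 2) array
+# enumerating all (i, j) pairs
+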
+
+def helper_phi_b(b, x):
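+    # log-normal-type signal response: phi_b(x) = b * x[0] * exp(b * x[1])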
+    return b * x[0] * jnp.exp(b * x[1])
+
+
+# %%
+b = 2.
+
+signal_response = partial(helper_phi_b, b)
+nll = jft.Gaussian(0., lambda x: x / jnp.sqrt(1.)) @ signal_response
+
+ham = jft.StandardHamiltonian(nll).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+GeoMetricKL = partial(jft.GeoMetricKL, ham)
+
+# %%
+n_pix_sqrt = 1000
+x = jnp.linspace(-4, 4, n_pix_sqrt)
+y = jnp.linspace(-4, 4, n_pix_sqrt)
+xx = cartesian_product((x, y))
+ham_everywhere = jnp.vectorize(ham, signature="(2)->()")(xx).reshape(
+    n_pix_sqrt, n_pix_sqrt
+)
+plt.imshow(
+    jnp.exp(-ham_everywhere.T),
+    extent=(x.min(), x.max(), y.min(), y.max()),
+    origin="lower"
+)
+plt.colorbar()
+plt.title("target distribution")
+plt.show()
+
+# %%
+n_mgvi_iterations = 30
+n_samples = [2] * (n_mgvi_iterations - 10) + [2] * 5 + [10, 10, 10, 10, 100]
+n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + [25] * 4
+absdelta = 1e-13
+
+initial_position = jnp.array([1., 1.])
+mkl_pos = 1e-2 * jft.Field(initial_position)
+
+mgvi_positions = []
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    mg_samples = MetricKL(
+        mkl_pos,
+        n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.},
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=mkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=mg_samples),
+            "hessp": partial(ham_metric, primals_samples=mg_samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "name": None
+            },
+            "name": "N"
+        }
+    )
+    mkl_pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(mkl_pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+    mgvi_positions.append(mkl_pos)
+
+# %%
+n_geovi_iterations = 15
+n_samples = [1] * (n_geovi_iterations - 10) + [2] * 5 + [10, 10, 10, 10, 100]
+n_newton_iterations = [7] * (n_geovi_iterations - 10) + [10] * 6 + [25] * 4
+absdelta = 1e-10
+
+initial_position = jnp.array([1., 1.])
+gkl_pos = 1e-2 * jft.Field(initial_position)
+
+for i in range(n_geovi_iterations):
+    print(f"geoVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    geo_samples = GeoMetricKL(
+        gkl_pos,
+        n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_name=None,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.},
+        non_linear_sampling_kwargs={
+            "cg_kwargs": {
+                "miniter": 0
+            },
+            "maxiter": 20
+        },
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=gkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=geo_samples),
+            "hessp": partial(ham_metric, primals_samples=geo_samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "miniter": 0,
+                "name": None
+            },
+            "name": "N"
+        }
+    )
+    gkl_pos = opt_state.x
+    msg = f"Post geoVI Iteration {i}: Energy {geo_samples.at(gkl_pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+n_pix_sqrt = 200
+x = jnp.linspace(-4.0, 4.0, n_pix_sqrt, endpoint=True)
+y = jnp.linspace(-4.0, 4.0, n_pix_sqrt, endpoint=True)
+X, Y = jnp.meshgrid(x, y)
+XY = jnp.array([X, Y]).T
+xy = XY.reshape((XY.shape[0] * XY.shape[1], 2))
+es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T
+
+# %%
+mkl_b_space_smpls = jnp.array([s.val for s in mg_samples.at(mkl_pos)])
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*mkl_b_space_smpls.T)
+ax.plot(*mkl_pos, "rx")
+plt.title("MGVI")
+plt.show()
+
+# %%
+gkl_b_space_smpls = jnp.array([s.val for s in geo_samples.at(gkl_pos)])
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*gkl_b_space_smpls.T)
+ax.plot(*gkl_pos, "rx")
+plt.title("GeoVI")
+plt.show()
+
+# %%
+initial_position = jnp.array([1., 1.])
+
+hmc_sampler = jft.HMCChain(
+    potential_energy=ham,
+    inverse_mass_matrix=1.,
+    position_proto=initial_position,
+    step_size=0.1,
+    num_steps=64,
+)
+
+chain, _ = hmc_sampler.generate_n_samples(
+    42, 1e-2 * initial_position, num_samples=100, save_intermediates=True
+)
+
+# %%
+b_space_smpls = chain.samples
+fig, ax = plt.subplots()
+ax.scatter(*b_space_smpls.T)
+plt.title("HMC (Metroplis-Hastings) samples")
+plt.show()
+
+# %%
+initial_position = jnp.array([1., 1.])
+
+nuts_sampler = jft.NUTSChain(
+    potential_energy=ham,
+    inverse_mass_matrix=0.5,
+    position_proto=initial_position,
+    step_size=0.4,
+    max_tree_depth=10,
+)
+
+nuts_n_samples = []
+ns_samples = [200, 1000, 1000000]
+for n_samples in ns_samples:
+    chain, _ = nuts_sampler.generate_n_samples(
+        43 + n_samples,
+        1e-2 * initial_position,
+        num_samples=n_samples,
+        save_intermediates=True
+    )
+    nuts_n_samples.append(chain.samples)
+
+# %%
+b_space_smpls = chain.samples
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*b_space_smpls.T, s=2.)
+plt.show()
+
+# %%
+plt.hist2d(
+    *b_space_smpls.T,
+    bins=[x, y],
+    range=[[x.min(), x.max()], [y.min(), y.max()]]
+)
+plt.colorbar()
+plt.show()
+
+# %%
+subplots = (3, 2)
+
+fig_width_pt = 426  # pt (a4paper, and such)
+inches_per_pt = 1 / 72.27
+fig_width_in = fig_width_pt * inches_per_pt
+fig_height_in = fig_width_in * 1. * (subplots[0] / subplots[1])
+fig_dims = (fig_width_in, fig_height_in)
+
+fig, ((ax1, ax4), (ax2, ax5), (ax3, ax6)
+     ) = plt.subplots(*subplots, figsize=fig_dims, sharex=True, sharey=True)
+
+ax1.set_title(r'$P(d=0|\xi_1, \xi_2) \cdot P(\xi_1, \xi_2)$')
+xx = cartesian_product((x, y))
+ham_everywhere = jnp.vectorize(ham, signature="(2)->()")(xx).reshape(
+    n_pix_sqrt, n_pix_sqrt
+)
+ax1.imshow(
+    jnp.exp(-ham_everywhere.T),
+    extent=(x.min(), x.max(), y.min(), y.max()),
+    origin="lower"
+)
+#ax1.colorbar()
+
+ax1.set_ylim([-4., 4.])
+ax1.set_xlim([-4., 4.])
+#ax1.autoscale(enable=True, axis='y', tight=True)
+asp = float(
+    jnp.diff(jnp.array(ax1.get_xlim()))[0] /
+    jnp.diff(jnp.array(ax1.get_ylim()))[0]
+)
+
+smplmarkersize = .3
+smplmarkercolor = 'k'
+
+linewidths = 0.5
+fontsize = 5
+potlabels = False
+
+ax2.set_title('MGVI')
+mkl_b_space_smpls = jnp.array([s.val for s in mg_samples.at(mkl_pos)])
+contour = ax2.contour(X, Y, es, linewidths=linewidths)
+ax2.clabel(contour, inline=True, fontsize=fontsize)
+ax2.scatter(*mkl_b_space_smpls.T, s=smplmarkersize, c=smplmarkercolor)
+ax2.plot(*mkl_pos, "rx")
+#ax2.set_aspect(asp)
+
+ax3.set_title('geoVI')
+gkl_b_space_smpls = jnp.array([s.val for s in geo_samples.at(gkl_pos)])
+contour = ax3.contour(X, Y, es, linewidths=linewidths)
+ax3.clabel(contour, inline=True, fontsize=fontsize)
+ax3.scatter(*gkl_b_space_smpls.T, s=smplmarkersize, c=smplmarkercolor)
+ax3.plot(*gkl_pos, "rx")
+#ax3.set_aspect(asp)
+
+for ax in (ax1, ax2, ax3):
+    ax.set_ylabel(r'$\xi_2$')
+ax3.set_xlabel(r'$\xi_1$')
+ax6.set_xlabel(r'$\xi_1$')
+
+for n, samples, ax in zip(ns_samples[:2], nuts_n_samples[:2], [ax4, ax5]):
+    ax.set_title(f"NUTS N={n}")
+    contour = ax.contour(X, Y, es, linewidths=linewidths)
+    #ax.clabel(contour, inline=True, fontsize=fontsize)
+    ax.scatter(*samples.T, s=smplmarkersize, c=smplmarkercolor)
+
+h, _, _ = jnp.histogram2d(
+    *nuts_n_samples[-1].T,
+    bins=[x, y],
+    range=[[x.min(), x.max()], [y.min(), y.max()]]
+)
+ax6.imshow(h.T, extent=(x.min(), x.max(), y.min(), y.max()), origin="lower")
+ax6.set_title(f'NUTS N={ns_samples[-1]:.0E}')
+
+fig.tight_layout()
+fig.savefig("pinch.pdf", bbox_inches='tight')
+print("final plot saved as pinch.pdf")
diff --git a/demos/re/nifty_to_jifty.py b/demos/re/nifty_to_jifty.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b82053ebd9c1eaacc7aa2ba741fee4f9476e715
--- /dev/null
+++ b/demos/re/nifty_to_jifty.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+from jax import numpy as jnp
+from jax import random
+from jax import jit
+from jax.config import config
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+# %%
+# ## Likelihood
+#
+# ### What is a Likelihood in jifty?
+#
+# * Very generally, the likelihood stores the cost term(s) for the final minimization
+#   * P(d|\xi) is a likelihood just like P(\xi) is a likelihood (w/ d := data, \xi := parameters)
+#   * Adding two likelihoods yields a likelihood again; thus P(d|\xi) + P(\xi) is just another likelihood
+# * Properties
+#   * Energy/Hamiltonian: negative log-probability
+#   * Left square root (L) of the metric (M; M = L L^\dagger): needed for sampling and minimization
+#   * Metric: needed for sampling and minimization; can be inferred from left sqrt metric
+#
+# ### Differences to NIFTy's `EnergyOperator`?
+#
+# * There are no operators in jifty, thus there is no EnergyOperator!
+# * NIFTy features many different energy classes; in jifty there is just one
+# * jifty needs to track the domain of the data without re-introducing operators
+#
+# ### What gives?
+#
+# * No manual tracking of the Jacobian
+# * No linear operators; this also means we cannot take the adjoint of the Jacobian :(
+# * Trivial to define new likelihoods
+
+
+def Gaussian(data, noise_cov_inv_sqrt):
+    # Simple but not very generic Gaussian energy
+    def hamiltonian(primals):
+        p_res = primals - data
+        l_res = noise_cov_inv_sqrt(p_res)
+        return 0.5 * jnp.sum(l_res**2)
+
+    def left_sqrt_metric(primals, tangents):
+        return noise_cov_inv_sqrt(tangents)
+
+    lsm_tangents_shape = jnp.shape(data)
+    # Better: `tree_map(ShapeWithDtype.from_leave, data)`
+
+    return jft.Likelihood(
+        hamiltonian,
+        left_sqrt_metric=left_sqrt_metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
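+# As noted above, the metric M = L L^\dagger can be inferred from the left
+# square root metric alone. A minimal sketch of the idea (not necessarily
+# jifty's exact implementation; `metric_from_lsm` is a hypothetical helper
+# and assumes `lsm_tangents_shape` holds an array prototype):
+def metric_from_lsm(likelihood, primals, tangents):
+    from jax import linear_transpose
+
+    # L at fixed primals
+    lsm = lambda t: likelihood.left_sqrt_metric(primals, t)
+    # transposing L requires an input prototype; this is why the likelihood
+    # tracks `lsm_tangents_shape`
+    lsm_dagger = linear_transpose(lsm, likelihood.lsm_tangents_shape)
+    # M applied to the tangents: L (L^\dagger t)
+    return lsm(*lsm_dagger(tangents))
+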
+
+seed = 42
+key = random.PRNGKey(seed)
+
+dims = (1024, )
+
+loglogslope = 2.
+power_spectrum = lambda k: 1. / (k**loglogslope + 1.)
+modes = jnp.arange((dims[0] / 2) + 1., dtype=float)
+harmonic_power = power_spectrum(modes)
+harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1]))
+
+# Specify the model
+correlated_field = lambda x: jft.correlated_field.hartley(
+    harmonic_power * x.val
+)
+signal_response = lambda x: jnp.exp(1. + correlated_field(x))
+noise_cov_inv_sqrt = lambda x: 0.1**-1 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.Field(random.normal(shape=dims, key=subkey))
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = 1. / noise_cov_inv_sqrt(jnp.ones(dims)
+                                     ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = Gaussian(data, noise_cov_inv_sqrt) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos_init = jft.Field(random.normal(shape=dims, key=subkey))
+pos = jft.Field(pos_init.val)
+
+n_newton_iterations = 10
+# Maximize the posterior using natural gradient scaling
+pos = jft.newton_cg(
+    fun_and_grad=ham_vg, x0=pos, hessp=ham_metric, maxiter=n_newton_iterations
+)
+
+fig, ax = plt.subplots()
+ax.plot(signal_response_truth, alpha=0.7, label="Signal")
+ax.plot(noise_truth, alpha=0.7, label="Noise")
+ax.plot(data, alpha=0.7, label="Data")
+ax.plot(signal_response(pos), alpha=0.7, label="Reconstruction")
+ax.legend()
+fig.tight_layout()
+fig.savefig("n2f_known_spectrum_MAP.png", dpi=400)
+plt.close()
+
+# ## Sampling
+#
+# ### How does sampling work in jifty?
+#
+# To sample from a likelihood, we need to be able to draw samples which have
+# the metric as covariance structure and we need to be able to apply the
+# inverse metric. The first part is trivial since we can use the left square
+# root of the metric associated with every likelihood:
+#
+#   \tilde{d} \leftarrow \mathcal{G}(0,\mathbb{1})
+#   t = L \tilde{d}
+#
+# with $t$ now having a covariance structure of
+#
+#   <t t^\dagger> = L <\tilde{d} \tilde{d}^\dagger> L^\dagger = M.
+#
+# We now need to apply the inverse metric in order to transform the sample
+# into one with the inverse metric as covariance. We can do so using the
+# conjugate gradient algorithm which yields the solution to $M s = t$, i.e.
+# applies the inverse of $M$ to $t$:
+#
+#   M s =  t
+#   s = M^{-1} t = cg(M, t) .
+#
+# ### Differences to NIFTy?
+#
+# * More generic implementation since the left square root of the metric can
+#   be applied independently from drawing samples
+# * By virtue of storing the left square root metric, no dedicated sampling
+#   method needs to be extended ever again
+#
+# ### What gives?
+#
+# The clearer separation of sampling and inverting the metric allows for a
+# better interplay of our methods with existing tools like JAX's cg
+# implementation.
+
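+# A minimal sketch of this recipe for a single likelihood (assuming
+# `lsm_tangents_shape` is a plain shape tuple as in the Gaussian defined
+# above; `draw_metric_sample` is a hypothetical helper, not jifty's API):
+def draw_metric_sample(likelihood, primals, key):
+    from jax.scipy.sparse.linalg import cg
+
+    # \tilde{d} ~ G(0, 1); then t = L \tilde{d} has covariance M
+    d_tilde = random.normal(key, shape=likelihood.lsm_tangents_shape)
+    t = likelihood.left_sqrt_metric(primals, d_tilde)
+    # applying M^{-1} via conjugate gradient turns t into a sample with
+    # covariance M^{-1}
+    s, _ = cg(lambda x: likelihood.metric(primals, x), t)
+    return s
+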
+n_mgvi_iterations = 3
+n_samples = 4
+n_newton_iterations = 5
+
+key, subkey = random.split(key)
+pos_init = jft.Field(random.normal(shape=dims, key=subkey))
+pos = jft.Field(pos_init.val)
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    mg_samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    pos = jft.newton_cg(
+        fun_and_grad=partial(ham_vg, primals_samples=mg_samples),
+        x0=pos,
+        hessp=partial(ham_metric, primals_samples=mg_samples),
+        maxiter=n_newton_iterations
+    )
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+post_sr_mean = jft.mean(tuple(signal_response(s) for s in mg_samples.at(pos)))
+fig, ax = plt.subplots()
+ax.plot(signal_response_truth, alpha=0.7, label="Signal")
+ax.plot(noise_truth, alpha=0.7, label="Noise")
+ax.plot(data, alpha=0.7, label="Data")
+ax.plot(post_sr_mean, alpha=0.7, label="Reconstruction")
+label = "Reconstructed samples"
+for s in mg_samples:
+    ax.plot(signal_response(s), color="gray", alpha=0.5, label=label)
+    label = None
+ax.legend()
+fig.tight_layout()
+fig.savefig("n2f_known_spectrum_MGVI.png", dpi=400)
+plt.close()
+
+# ## Correlated field
+#
+# ### Correlated fields in jifty
+#
+# * `CorrelatedFieldMaker` to track amplitudes along different axes
+# * `add_fluctuations` method to amend new amplitudes
+# * Zero-mode is tracked separately from the amplitudes
+# * `finalize` normalizes the amplitudes and takes their outer product
+# * Amplitudes are independent of the stack of amplitudes tracked in the correlated field, i.e. no normalization happens within the amplitude
+#
+# ### Differences to NIFTy
+#
+# A correlated field with a single axis but arbitrary dimensionality in NIFTy
+# is mostly equivalent to one in jifty. Though since jifty does not track
+# domains, everything related to harmonic modes and distributing power is
+# contained within the correlated field model.
+#
+# The normalization and factorization of amplitudes is done only once in
+# `finalize`. This conceptually simplifies the amplitude model by a lot.
+#
+# ### What gives?
+#
+# * Conceptually simpler amplitude model
+# * No domains --> no domain mismatches --> broadcasting \o/
+# * No domains --> no domain mismatches --> more errors :(
+
+dims_ax1 = (64, )
+dims_ax2 = (128, )
+cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
+cf_fl = {
+    "fluctuations": (1e-1, 5e-3),
+    "loglogavgslope": (-1., 1e-2),
+    "flexibility": (1e+0, 5e-1),
+    "asperity": (5e-1, 1e-1),
+    "harmonic_domain_type": "Fourier"
+}
+cfm = jft.CorrelatedFieldMaker("cf")
+cfm.set_amplitude_total_offset(**cf_zm)
+d = 1. / dims_ax1[0]
+cfm.add_fluctuations(dims_ax1, distances=d, **cf_fl, prefix="ax1")
+d = 1. / dims_ax2[0]
+cfm.add_fluctuations(dims_ax2, distances=d, **cf_fl, prefix="ax2")
+correlated_field, ptree = cfm.finalize()
+
+signal_response = lambda x: correlated_field(x)
+noise_cov = lambda x: 5**2 * x
+noise_cov_inv = lambda x: 5**-2 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.random_like(subkey, ptree)
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = jnp.sqrt(
+    noise_cov(jnp.ones(signal_response_truth.shape))
+) * random.normal(shape=signal_response_truth.shape, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = jft.Gaussian(data, noise_cov_inv) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos_init = jft.Field(jft.random_like(subkey, ptree))
+pos = jft.Field(pos_init.val)
+
+n_mgvi_iterations = 3
+n_samples = 4
+n_newton_iterations = 10
+
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    mg_samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    pos = jft.newton_cg(
+        fun_and_grad=partial(ham_vg, primals_samples=mg_samples),
+        x0=pos,
+        hessp=partial(ham_metric, primals_samples=mg_samples),
+        maxiter=n_newton_iterations
+    )
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+namps = cfm.get_normalized_amplitudes()
+post_sr_mean = jft.mean(tuple(signal_response(s) for s in mg_samples.at(pos)))
+post_namps1_mean = jft.mean(tuple(namps[0](s)[1:] for s in mg_samples.at(pos)))
+post_namps2_mean = jft.mean(tuple(namps[1](s)[1:] for s in mg_samples.at(pos)))
+to_plot = [
+    ("Signal", signal_response_truth, "im"),
+    ("Noise", noise_truth, "im"),
+    ("Data", data, "im"),
+    ("Reconstruction", post_sr_mean, "im"),
+    ("Ax1", (namps[0](pos_truth)[1:], post_namps1_mean), "loglog"),
+    ("Ax2", (namps[1](pos_truth)[1:], post_namps2_mean), "loglog"),
+]
+fig, axs = plt.subplots(2, 3, figsize=(16, 9))
+for ax, (title, field, tp) in zip(axs.flat, to_plot):
+    ax.set_title(title)
+    if tp == "im":
+        im = ax.imshow(field, cmap="inferno")
+        plt.colorbar(im, ax=ax, orientation="horizontal")
+    else:
+        ax_plot = ax.loglog if tp == "loglog" else ax.plot
+        field = field if isinstance(field, (tuple, list)) else (field, )
+        for f in field:
+            ax_plot(f, alpha=0.7)
+fig.tight_layout()
+fig.savefig("n2f_unknown_factorizing_spectra.png", dpi=400)
+plt.close()
diff --git a/demos/re/refine.py b/demos/re/refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..872662fe6356c5fe2fdf6a8312bd5772870025b6
--- /dev/null
+++ b/demos/re/refine.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from collections import namedtuple
+from functools import partial
+import sys
+
+import jax
+from jax import numpy as jnp
+from jax import random
+from jax.scipy.interpolate import RegularGridInterpolator
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.special import kv as mod_bessel2
+
+import nifty8.re as jft
+
+jax.config.update("jax_enable_x64", True)
+# jax.config.update("jax_debug_nans", True)
+
+Timed = namedtuple("Timed", ("time", "number"), rename=True)
+
+
+def timeit(stmt, setup=lambda: None, number=None):
+    import timeit
+
+    if number is None:
+        number, _ = timeit.Timer(stmt).autorange()
+
+    setup()
+    t = timeit.timeit(stmt, number=number) / number
+    return Timed(time=t, number=number)
+
+
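+# Matérn covariance kernel
+#   C(r) = s^2 2^(1 - dof) / Gamma(dof) * (sqrt(2 dof) r / cutoff)^dof
+#          * K_dof(sqrt(2 dof) r / cutoff)
+# with scale s and K_dof the modified Bessel function of the second kind
+# (scipy's `kv`); `_matern_kernel` below implements exactly this.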
+def _matern_kernel(distance, scale, cutoff, dof):
+    from jax.scipy.special import gammaln
+
+    reg_dist = jnp.sqrt(2 * dof) * distance / cutoff
+    return scale**2 * 2**(1 - dof) / jnp.exp(
+        gammaln(dof)
+    ) * (reg_dist)**dof * mod_bessel2(dof, reg_dist)
+
+
+n_dof = 100
+n_dist = 1000
+min_reg_dist = 1e-6  # approx. lowest resolution of `_matern_kernel` at float64
+max_reg_dist = 8e+2  # approx. highest resolution of `_matern_kernel` at float64
+eps = 8. * jnp.finfo(jnp.array(min_reg_dist).dtype.type).eps
+dof_grid = np.linspace(0., 15., n_dof)
+reg_dist_grid = np.logspace(
+    np.log(min_reg_dist * (1. - eps)),
+    np.log(max_reg_dist * (1. + eps)),
+    base=np.e,
+    num=n_dist
+)
+grid = np.meshgrid(dof_grid, reg_dist_grid, indexing="ij")
+_unsafe_ln_mod_bessel2 = RegularGridInterpolator(
+    (dof_grid, reg_dist_grid), jnp.log(mod_bessel2(*grid)), fill_value=-np.inf
+)
+
+
+def matern_kernel(distance, scale, cutoff, dof):
+    from jax.scipy.special import gammaln
+
+    reg_dist = jnp.sqrt(2 * dof) * distance / cutoff
+    dof, reg_dist = jnp.broadcast_arrays(dof, reg_dist)
+
+    # Never produce NaNs (https://github.com/google/jax/issues/1052)
+    reg_dist = reg_dist.clip(min_reg_dist, max_reg_dist)
+
+    ln_kv = jnp.squeeze(
+        _unsafe_ln_mod_bessel2(jnp.stack((dof, reg_dist), axis=-1))
+    )
+    corr = 2**(1 - dof) * jnp.exp(ln_kv - gammaln(dof)) * (reg_dist)**dof
+    return scale**2 * corr
+
+
+scale, cutoff, dof = 1., 80., 3 / 2
+
+# %%
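+# Tabulate the exact kernel once and interpolate linearly afterwards; scipy's
+# `kv` is not JAX-traceable, so only the interpolated kernels below can be
+# used inside jit-compiled code.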
+x = np.logspace(-6, 11, base=jnp.e, num=int(1e+5))
+y = _matern_kernel(x, scale, cutoff, dof)
+y = jnp.nan_to_num(y, nan=0.)
+kernel = partial(jnp.interp, xp=x, fp=y)
+kernel_j = partial(matern_kernel, scale=scale, cutoff=cutoff, dof=dof)
+
+fig, ax = plt.subplots()
+x_s = x[x < 10 * cutoff]
+ax.plot(x_s, kernel(x_s))
+ax.plot(x_s, kernel_j(x_s))
+ax.plot(x_s, jnp.exp(-(x_s / (2. * cutoff))**2))
+ax.set_yscale("log")
+fig.savefig("re_refine_kernel.png", transparent=True)
+plt.close()
+
+# %%
+# Quick demo of the correlated field scheme that is to be used in the following
+cf_kwargs = {"shape0": (12, ), "distances0": (50., ), "kernel": kernel}
+
+cf = jft.RefinementField(**cf_kwargs, depth=5)
+xi = jft.random_like(random.PRNGKey(42), cf.shapewithdtype)
+
+fig, ax = plt.subplots(figsize=(8, 4))
+for i in range(cf.chart.depth):
+    cf_lvl = jft.RefinementField(**cf_kwargs, depth=i)
+    x = jnp.mgrid[tuple(slice(sz) for sz in cf_lvl.chart.shape)]
+    x = cf.chart.ind2rg(x, i)[0]
+    f_lvl = cf_lvl(xi[:i + 1])
+    ax.step(x, f_lvl, alpha=0.7, where="mid", label=f"level {i}")
+# ax.set_frame_on(False)
+# ax.set_xticks([], [])
+# ax.set_yticks([], [])
+ax.legend()
+fig.tight_layout()
+fig.savefig("re_refine_field_layers.png", transparent=True)
+plt.close()
+
+
+# %%
+def parametrized_kernel(xi, verbose=False):
+    scale = jnp.exp(-0.5 + 0.2 * xi["lat_scale"])
+    cutoff = jnp.exp(4. + 1e-2 * xi["lat_cutoff"])
+    # dof is taken from the global scope; uncomment to infer it as well:
+    # dof = jnp.exp(0.5 + 0.1 * xi["lat_dof"])
+    # alternative squared-exponential kernel:
+    # kernel = lambda r: xi["scale"] * jnp.exp(-(r / xi["cutoff"])**2)
+    if verbose:
+        print(f"{scale=}, {cutoff=}, {dof=}")
+
+    return partial(matern_kernel, scale=scale, cutoff=cutoff, dof=dof)
+
+
+def signal_response(xi):
+    return cf(xi["excitations"], parametrized_kernel(xi))
+
+
+n_std = 0.5
+
+key = random.PRNGKey(45)
+key, *key_splits = random.split(key, 4)
+
+xi_truth = jft.random_like(key_splits.pop(), cf.shapewithdtype)
+d = cf(xi_truth, kernel)
+d += n_std * random.normal(key_splits.pop(), shape=d.shape)
+
+xi_swd = {
+    "excitations": cf.shapewithdtype,
+    "lat_scale": jft.ShapeWithDtype(()),
+    "lat_cutoff": jft.ShapeWithDtype(()),
+}
+pos = 1e-4 * jft.Field(jft.random_like(key_splits.pop(), xi_swd))
+
+n_mgvi_iterations = 15
+n_newton_iterations = 15
+n_samples = 2
+absdelta = 1e-5
+
+nll = jft.Gaussian(d, noise_std_inv=lambda x: x / n_std) @ signal_response
+ham = jft.StandardHamiltonian(nll)  # + 0.5 * jft.norm(x, ord=2, ravel=True)
+ham_vg = jax.jit(jft.mean_value_and_grad(ham))
+ham_metric = jax.jit(jft.mean_metric(ham.metric))
+MetricKL = jax.jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+# %%
+# Minimize the potential
+for i in range(n_mgvi_iterations):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=samples),
+            "hessp": partial(ham_metric, primals_samples=samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+fig, ax = plt.subplots(figsize=(8, 4))
+ax.plot(d, label="data")
+ax.plot(cf(xi_truth, kernel), label="truth")
+ax.plot(samples.at(pos).mean(signal_response), label="reconstruction")
+ax.legend()
+fig.tight_layout()
+fig.savefig("re_refine_reconstruction.png", transparent=True)
+plt.close()
+
+# %%
+cf_bench = jft.RefinementField(shape0=(12, ), kernel=kernel, depth=15)
+xi_wo = jft.random_like(random.PRNGKey(42), jft.Field(cf_bench.shapewithdtype))
+xi_w = jft.random_like(
+    random.PRNGKey(42),
+    jft.Field(
+        {
+            "excitations": cf_bench.shapewithdtype,
+            "lat_scale": jft.ShapeWithDtype(()),
+            "lat_cutoff": jft.ShapeWithDtype(()),
+        }
+    )
+)
+
+
+def signal_response_bench(xi):
+    return cf_bench(xi["excitations"], parametrized_kernel(xi))
+
+
+d = signal_response_bench(0.5 * xi_w)
+nll_wo_fwd = jft.Gaussian(d, noise_std_inv=lambda x: x / n_std)
+ham_w = jft.StandardHamiltonian(nll_wo_fwd @ signal_response_bench)
+ham_wo = jft.StandardHamiltonian(nll_wo_fwd @ cf_bench)
+
+# %%
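+# Benchmark on the CPU and, if different, the default backend. JAX dispatches
+# asynchronously, hence `block_until_ready` is called before stopping the
+# clock.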
+all_backends = {"cpu"}
+all_backends |= {jax.default_backend()}
+for backend in all_backends:
+    device_kw = {"device": jax.devices(backend=backend)[0]}
+    device_put = partial(jax.device_put, **device_kw)
+
+    cf_vag_bench = jax.jit(jax.value_and_grad(ham_w), **device_kw)
+    x = device_put(xi_w)
+    _ = jax.block_until_ready(cf_vag_bench(x))
+    t = timeit(lambda: jax.block_until_ready(cf_vag_bench(x)))
+    ti, num = t.time, t.number
+
+    msg = f"{backend.upper()} :: Shape {str(cf_bench.chart.shape):>16s} ({num:6d} loops) :: JAX w/ learnable {ti:4.2e}"
+    print(msg, file=sys.stderr)
+
+    cf_vag_bench = jax.jit(jax.value_and_grad(ham_wo), **device_kw)
+    x = device_put(xi_wo)
+    _ = jax.block_until_ready(cf_vag_bench(x))
+    t = timeit(lambda: jax.block_until_ready(cf_vag_bench(x)))
+    ti, num = t.time, t.number
+
+    msg = f"{backend.upper()} :: Shape {str(cf_bench.chart.shape):>16s} ({num:6d} loops) :: JAX w/o learnable {ti:4.2e}"
+    print(msg, file=sys.stderr)
diff --git a/setup.py b/setup.py
index 015e37aeb3aea54778f2a0441e0bc13dc7e0aae4..2de928b7fb12c546fa020b72f4b2f830787c9e4f 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,8 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from functools import reduce
+import operator
 import os
 import site
 import sys
@@ -30,6 +32,14 @@ with open("README.md") as f:
     long_description = f.read()
 description = "Library for signal inference algorithms that operate regardless of the underlying grids and their resolutions."
 
+extras_require = {
+    "re": ("jax", ),
+    "native": ("ducc0", "finufft"),
+    "doc": ("sphinx", "pydata-sphinx-theme", "jupyter", "jupytext"),
+    "util": ("astropy", ),
+}
+extras_require["full"] = reduce(operator.add, extras_require.values())
+
 setup(name="nifty8",
       version=__version__,
       author="Martin Reinecke",
@@ -48,6 +58,7 @@ setup(name="nifty8",
       license="GPLv3",
       setup_requires=['scipy>=1.4.1', 'numpy>=1.17'],
       install_requires=['scipy>=1.4.1', 'numpy>=1.17'],
+      extras_require=extras_require,
       python_requires='>=3.7',
       classifiers=[
           "Development Status :: 5 - Production/Stable",
diff --git a/src/__init__.py b/src/__init__.py
index 8d0711b93e754980a15aa5138f3f18a376069154..9ac933d086e45c38ffaa16d8963bbfe76f0c557e 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -108,5 +108,11 @@ from .operator_tree_optimiser import optimise_operator
 
 from .ducc_dispatch import set_nthreads, nthreads
 
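+# The JAX-backed sub-package is optional; if JAX is not installed, importing
+# it fails silently (see the "re" extra in setup.py).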
+try:
+    from . import re
+    from . import nifty2jax
+except ImportError:
+    pass
+
 # We deliberately don't set __all__ here, because we don't want people to do a
 # "from nifty8 import *"; that would swamp the global namespace.
diff --git a/src/domains/structured_domain.py b/src/domains/structured_domain.py
index 3ab7d9d4bdce268311a4dbb115fc6a9b17ed1110..0c3ad3ac8461531320094afdd81895125fc1abfa 100644
--- a/src/domains/structured_domain.py
+++ b/src/domains/structured_domain.py
@@ -70,7 +70,7 @@ class StructuredDomain(Domain):
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             An array containing the k vector lengths
         """
         raise NotImplementedError
diff --git a/src/extra.py b/src/extra.py
index c609295d9ec7aaefadd29f76d902400f217d3099..1cffcd7c2718cb289a9d28e3dc782d8790f177d6 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -107,7 +107,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
     ----------
     op : Operator
         Operator which shall be checked.
-    loc : Field or MultiField
+    loc : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         An Field or MultiField instance which has the same domain
         as op. The location at which the gradient is checked
     tol : float
diff --git a/src/field.py b/src/field.py
index dcf3c1230efd74a10d3e14ecf9d1997141cf8f97..144ce2320634767150f9fb22ed12589a969929be 100644
--- a/src/field.py
+++ b/src/field.py
@@ -53,6 +53,10 @@ class Field(Operator):
         if not isinstance(val, np.ndarray):
             if np.isscalar(val):
                 val = np.broadcast_to(val, domain.shape)
+            elif np.shape(val) == domain.shape:
+                # If NumPy thinks the shapes are equal, attempt to convert to
+                # NumPy. This is especially helpful for JAX DeviceArrays.
+                val = np.asarray(val)
             else:
                 raise TypeError("val must be of type numpy.ndarray")
         if domain.shape != val.shape:
@@ -276,7 +280,7 @@ class Field(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
 
         Returns
         -------
@@ -294,7 +298,7 @@ class Field(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             x must be defined on the same domain as `self`.
 
         spaces : None, int or tuple of int
@@ -326,7 +330,7 @@ class Field(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             x must be defined on the same domain as `self`.
 
         Returns
diff --git a/src/library/adjust_variances.py b/src/library/adjust_variances.py
index 883d6cf5da87e574c346f4cca6719aab7339a6a5..5103261ea948d1c1f022033b87b5aa3115a84d6d 100644
--- a/src/library/adjust_variances.py
+++ b/src/library/adjust_variances.py
@@ -44,9 +44,9 @@ def make_adjust_variances_hamiltonian(a,
     xi : Operator
         Field Adapter selecting a part of position.
         xi is desired to be a Gaussian white Field.
-    position : Field, MultiField
+    position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField`
         Contains the initial values for the operators a and xi, to be adjusted
-    samples : Field, MultiField
+    samples : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField`
         Residual samples of position.
     scaling : Float
         Optional rescaling of the Likelihood.
@@ -55,7 +55,7 @@ def make_adjust_variances_hamiltonian(a,
 
     Returns
     -------
-    StandardHamiltonian
+    :class:`nifty8.operators.energy_operators.StandardHamiltonian`
         A Hamiltonian that can be used for further minimization.
     """
 
@@ -91,7 +91,7 @@ def do_adjust_variances(position, A, minimizer, xi_key='xi', samples=[]):
 
     Parameters
     ----------
-    position : Field, MultiField
+    position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField`
         Contains the initial values for amplitude_operator and the key xi_key,
         to be adjusted.
     A : Operator
@@ -101,7 +101,7 @@ def do_adjust_variances(position, A, minimizer, xi_key='xi', samples=[]):
     xi_key : String
         Key of the Field containing undesired variations. This Field is
         contained in position.
-    samples : Field, MultiField, optional
+    samples : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField`, optional
         Residual samples of position. If samples are supplied then phi remains
         only approximately constant. Default: [].
 
diff --git a/src/library/correlated_fields.py b/src/library/correlated_fields.py
index 6b17481788288b247e5de4f356e4ed0d4eb500a2..9b962c8df7449b9e41d3a644afd9df504653eaf4 100644
--- a/src/library/correlated_fields.py
+++ b/src/library/correlated_fields.py
@@ -27,6 +27,7 @@ import numpy as np
 from .. import utilities
 from ..domain_tuple import DomainTuple
 from ..domains.power_space import PowerSpace
+from ..domains.rg_space import RGSpace
 from ..domains.unstructured_domain import UnstructuredDomain
 from ..field import Field
 from ..logger import logger
@@ -38,7 +39,6 @@ from ..operators.distributors import PowerDistributor
 from ..operators.endomorphic_operator import EndomorphicOperator
 from ..operators.harmonic_operators import HarmonicTransformOperator
 from ..operators.linear_operator import LinearOperator
-from ..operators.mask_operator import MaskOperator
 from ..operators.normal_operators import LognormalTransform, NormalTransform
 from ..operators.operator import Operator
 from ..operators.simple_linear_operators import VdotOperator, ducktape
@@ -453,6 +453,16 @@ class CorrelatedFieldMaker:
         self._prefix = prefix
         self._total_N = total_N
 
+        try:
+            from .. import re as jft
+
+            if total_N != 0:
+                warn(f"unable to add JAX operator for total_N={total_N}")
+                raise ImportError("short-circuit JAX init")
+            self._jax_cfm = jft.CorrelatedFieldMaker(prefix=prefix)
+        except ImportError:
+            self._jax_cfm = None
+
     def add_fluctuations(self,
                          target_subdomain,
                          fluctuations,
@@ -566,6 +576,25 @@ class CorrelatedFieldMaker:
                          target_subdomain[-1].total_volume,
                          pre + 'spectrum', dofdex)
 
+        is_rg = all(isinstance(dom, RGSpace) for dom in target_subdomain)
+        if self._jax_cfm is not None and (len(dofdex) > 0 or index is not None
+                                          or not is_rg):
+            warn(f"unable to add JAX operator for {target_subdomain}")
+            self._jax_cfm = None
+        if self._jax_cfm is not None:
+            dists = tuple(e for di in target_subdomain for e in di.distances)
+            self._jax_cfm.add_fluctuations(
+                shape=target_subdomain.shape,
+                distances=dists,
+                fluctuations=fluctuations,
+                loglogavgslope=loglogavgslope,
+                flexibility=flexibility,
+                asperity=asperity,
+                prefix=str(prefix),
+                harmonic_domain_type="fourier",
+                non_parametric_kind="power",
+            )
+            amp._jax_expr = self._jax_cfm.fluctuations[-1]
+
         if index is not None:
             self._a.insert(index, amp)
             self._target_subdomains.insert(index, target_subdomain)
@@ -652,6 +681,10 @@ class CorrelatedFieldMaker:
         amp = _AmplitudeMatern(pow_spc, scale, cutoff, loglogslope,
                                totvol)
 
+        if self._jax_cfm is not None:
+            warn(f"unable to add JAX operator for Matern fluctuations")
+            self._jax_cfm = None
+
         self._a.append(amp)
         self._target_subdomains.append(target_subdomain)
 
@@ -684,12 +717,15 @@ class CorrelatedFieldMaker:
             logger.warning("Overwriting the previous mean offset and zero-mode")
 
         self._offset_mean = offset_mean
+        jax_offset_std = offset_std
         if offset_std is None:
             self._azm = 0.
         elif np.isscalar(offset_std) and offset_std == 1.:
             self._azm = 1.
+            jax_offset_std = lambda _: 1.
         elif isinstance(offset_std, Operator):
             self._azm = offset_std
+            jax_offset_std = offset_std.jax_expr
         else:
             if dofdex is None:
                 dofdex = np.full(self._total_N, 0)
@@ -710,6 +746,21 @@ class CorrelatedFieldMaker:
                 zm = _Distributor(dofdex, zm.target, UnstructuredDomain(self._total_N)) @ zm
             self._azm = zm
 
+        if self._jax_cfm is not None and dofdex is not None and len(dofdex) > 0:
+            warn(f"unable to add JAX operator for dofdex={dofdex}")
+            self._jax_cfm = None
+        if self._jax_cfm is not None:
+            try:
+                self._jax_cfm.set_amplitude_total_offset(
+                    offset_mean=offset_mean, offset_std=jax_offset_std
+                )
+                if not isinstance(self._azm, float):
+                    self._azm._jax_expr = self._jax_cfm.azm
+            except TypeError as e:
+                self._jax_cfm = None
+                warn(f"no JAX operator for this configuration;\n{e}")
+
     def finalize(self, prior_info=100):
         """Finishes model construction process and returns the constructed
         operator.
@@ -766,6 +817,10 @@ class CorrelatedFieldMaker:
                 offset = float(offset)
                 op = Adder(full(op.target, offset)) @ op
         self.statistics_summary(prior_info)
+
+        if self._jax_cfm is not None:
+            cf, _ = self._jax_cfm.finalize()
+            op._jax_expr = cf
         return op
 
     def statistics_summary(self, prior_info):
@@ -829,8 +884,13 @@ class CorrelatedFieldMaker:
         elif self.azm == 1:
             return self.fluctuations
 
+        if self._jax_cfm:
+            normed_amps_jax = self._jax_cfm.get_normalized_amplitudes()
+        else:
+            normed_amps_jax = (None, ) * len(self._a)
+
         normal_amp = []
-        for amp in self._a:
+        for amp, na_jax in zip(self._a, normed_amps_jax):
             a_target = amp.target
             a_space = 0 if not hasattr(amp, "_space") else amp._space
             a_pp = amp.target[a_space]
@@ -852,7 +912,9 @@ class CorrelatedFieldMaker:
             zm_normalization = zm_unmask @ (
                 zm_mask @ azm_expander(self.azm.ptw("reciprocal"))
             )
-            normal_amp.append(zm_normalization * amp)
+            na = zm_normalization * amp
+            na._jax_expr = na_jax
+            normal_amp.append(na)
         return tuple(normal_amp)
 
     @property
@@ -865,12 +927,16 @@ class CorrelatedFieldMaker:
         normal_amp = self.get_normalized_amplitudes()[0]
 
         if np.isscalar(self.azm):
-            return normal_amp
+            na = normal_amp
         else:
             expand = ContractionOperator(
                 normal_amp.target, len(normal_amp.target) - 1
             ).adjoint
-            return normal_amp * (expand @ self.azm)
+            na = normal_amp * (expand @ self.azm)
+
+        if self._jax_cfm:
+            na._jax_expr = self._jax_cfm.amplitude
+        return na
 
     @property
     def power_spectrum(self):
diff --git a/src/library/correlated_fields_simple.py b/src/library/correlated_fields_simple.py
index ca754dd243d231456df5f7ed5aa2152355d9961c..57d53e39f959fed5dc106cb50dbfc6c9225aac84 100644
--- a/src/library/correlated_fields_simple.py
+++ b/src/library/correlated_fields_simple.py
@@ -17,6 +17,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 import numpy as np
+from warnings import warn
 
 from ..domain_tuple import DomainTuple
 from ..domains.power_space import PowerSpace
@@ -130,4 +131,38 @@ def SimpleCorrelatedField(
         op = Adder(full(op.target, float(offset_mean))) @ op
     op.amplitude = a
     op.power_spectrum = a**2
+
+    try:
+        from .. import re as jft
+        from .. import RGSpace
+
+        if not all(isinstance(dom, RGSpace) for dom in op.target):
+            warn(f"unable to add JAX operator for {op.target!r}")
+            raise ImportError("short-circuit JAX init")
+
+        dists = tuple(e for di in op.target for e in di.distances)
+        cfm = jft.CorrelatedFieldMaker(prefix=prefix)
+        cfm.add_fluctuations(
+            shape=op.target.shape,
+            distances=dists,
+            fluctuations=fluctuations,
+            loglogavgslope=loglogavgslope,
+            flexibility=flexibility,
+            asperity=asperity,
+            prefix="",
+            harmonic_domain_type="fourier",
+            non_parametric_kind="power",
+        )
+        cfm.set_amplitude_total_offset(
+            offset_mean=offset_mean, offset_std=offset_std
+        )
+        cf, _ = cfm.finalize()
+
+        op._jax_expr = cf
+        op.amplitude._jax_expr = cfm.amplitude
+        op.power_spectrum._jax_expr = cfm.power_spectrum
+    except (ImportError, TypeError) as e:
+        if isinstance(e, TypeError):
+            warn(f"no JAX operator for this configuration;\n{e}")
+
     return op
diff --git a/src/library/special_distributions.py b/src/library/special_distributions.py
index ad12218fb366bd35de9381de4e3c75347c27c039..1a0bccd1025960e1054fe2bea547da4b89e83d48 100644
--- a/src/library/special_distributions.py
+++ b/src/library/special_distributions.py
@@ -78,6 +78,23 @@ class _InterpolationOperator(Operator):
         self._deriv = self._interpolator.derivative()
         self._inv_table_func = inv_table_func
 
+        try:
+            from jax import numpy as jnp
+
+            def jax_expr(x):
+                if inv_table_func is not None:
+                    ve = (
+                        "cannot translate arbitrary inverse"
+                        f" table function {inv_table_func!r}"
+                    )
+                    raise ValueError(ve)
+                return jnp.interp(x, self._xs, self._table)
+
+            self._jax_expr = jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x):
         self._check_input(x)
         lin = x.jac is not None
@@ -118,7 +135,7 @@ class InverseGammaOperator(Operator):
         time the domain and the target of the operator.
     alpha : float
         The alpha-parameter of the inverse-gamma distribution.
-    q : float or Field
+    q : float or :class:`nifty8.field.Field`
         The q-parameter of the inverse-gamma distribution.
     mode: float
         The mode of the inverse-gamma distribution.
@@ -155,6 +172,14 @@ class InverseGammaOperator(Operator):
             op = makeOp(self._q) @ op
         self._op = op
 
+        try:
+            from ..re.stats_distributions import invgamma_prior
+
+            q_val = self._q.val if isinstance(self._q, Field) else self._q
+            self._jax_expr = invgamma_prior(float(self._alpha), q_val)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x):
         return self._op(x)
 
diff --git a/src/library/variational_models.py b/src/library/variational_models.py
index ac64ce5310bfb215bae9f9cff2b12ed0a08b5c20..82ed4552b4d2a376af37f21c3b67b5fb664241cb 100644
--- a/src/library/variational_models.py
+++ b/src/library/variational_models.py
@@ -50,7 +50,7 @@ class MeanFieldVI:
 
     Parameters
     ----------
-    position : Field
+    position : :class:`nifty8.field.Field`
         The initial estimate of the approximate mean parameter.
     hamiltonian : Energy
         Hamiltonian of the approximated probability distribution.
@@ -62,7 +62,7 @@ class MeanFieldVI:
         doubles. Mirroring samples stabilizes the KL estimate as extreme sample
         variation is counterbalanced. Since it improves stability in many
         cases, it is recommended to set `mirror_samples` to `True`.
-    initial_sig : positive Field or positive float
+    initial_sig : positive :class:`nifty8.field.Field` or positive float
         The initial estimate of the standard deviation.
     comm : MPI communicator or None
         If not None, samples will be distributed as evenly as possible across
@@ -140,7 +140,7 @@ class FullCovarianceVI:
 
     Parameters
     ----------
-    position : Field
+    position : :class:`nifty8.field.Field`
         The initial estimate of the approximate mean parameter.
     hamiltonian : Energy
         Hamiltonian of the approximated probability distribution.
diff --git a/src/linearization.py b/src/linearization.py
index f1a590c0608c8ffe09c4a2ddd8f9ec7a9350f037..75c4ccf894bf96d48b809430260b59bee922bf28 100644
--- a/src/linearization.py
+++ b/src/linearization.py
@@ -29,7 +29,7 @@ class Linearization(Operator):
 
     Parameters
     ----------
-    val : Field or MultiField
+    val : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The value of the operator application.
     jac : LinearOperator
         The Jacobian.
@@ -52,7 +52,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        val : Field or MultiField
+        val : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             the value of the operator application
         jac : LinearOperator
             the Jacobian
@@ -83,7 +83,7 @@ class Linearization(Operator):
 
     @property
     def val(self):
-        """Field or MultiField : the value"""
+        """:class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` : the value"""
         return self._val
 
     @property
@@ -93,7 +93,7 @@ class Linearization(Operator):
 
     @property
     def gradient(self):
-        """Field or MultiField : the gradient
+        """:class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` : the gradient
 
         Notes
         -----
@@ -198,7 +198,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        other : Field or MultiField or Linearization
+        other : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Linearization
 
         Returns
         -------
@@ -223,7 +223,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        other : Field or MultiField or Linearization
+        other : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Linearization
 
         Returns
         -------
@@ -292,7 +292,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        field : Field or Multifield
+        field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             the field to be converted
         want_metric : bool
             If True, the metric will be computed for other Linearizations
@@ -313,7 +313,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        field : Field or Multifield
+        field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             the field to be converted
         want_metric : bool
             If True, the metric will be computed for other Linearizations
@@ -338,7 +338,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        field : Field or Multifield
+        field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             the field to be converted
         want_metric : bool
             If True, the metric will be computed for other Linearizations
@@ -367,7 +367,7 @@ class Linearization(Operator):
 
         Parameters
         ----------
-        field : Multifield
+        field : :class:`nifty8.multi_field.MultiField`
             the field to be converted
         constants : list of string
             the MultiField components for which the Jacobian should be
diff --git a/src/minimization/descent_minimizers.py b/src/minimization/descent_minimizers.py
index 4c50ed0f644560b14adee3eed37273fc2331bf8f..9280392d2b4aa160b4020dfd445bcbb6834b6701 100644
--- a/src/minimization/descent_minimizers.py
+++ b/src/minimization/descent_minimizers.py
@@ -125,7 +125,7 @@ class DescentMinimizer(Minimizer):
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
            The descent direction.
         """
         raise NotImplementedError
@@ -316,9 +316,9 @@ class _InformationStore:
     ----------
     max_history_length : int
         Maximum number of stored past updates.
-    x0 : Field
+    x0 : :class:`nifty8.field.Field`
         Initial position in variable space.
-    gradient : Field
+    gradient : :class:`nifty8.field.Field`
         Gradient at position x0.
 
     Attributes
@@ -329,9 +329,9 @@ class _InformationStore:
         Circular buffer of past position differences, which are Fields.
     y : List
         Circular buffer of past gradient differences, which are Fields.
-    last_x : Field
+    last_x : :class:`nifty8.field.Field`
         Latest position in variable space.
-    last_gradient : Field
+    last_gradient : :class:`nifty8.field.Field`
         Gradient at latest position.
     k : int
         Number of updates that have taken place
diff --git a/src/minimization/energy.py b/src/minimization/energy.py
index 5981612310224b05df8dc1a9995c7f85c90e9ef2..1a8c7922403f0b58c2c07bd4c72c1832e975ee65 100644
--- a/src/minimization/energy.py
+++ b/src/minimization/energy.py
@@ -26,7 +26,7 @@ class Energy(metaclass=NiftyMeta):
 
     Parameters
     ----------
-    position : Field
+    position : :class:`nifty8.field.Field`
         The input parameter of the scalar function.
 
     Notes
@@ -51,7 +51,7 @@ class Energy(metaclass=NiftyMeta):
 
         Parameters
         ----------
-        position : Field
+        position : :class:`nifty8.field.Field`
             Location in parameter space for the new Energy object.
 
         Returns
@@ -64,7 +64,7 @@ class Energy(metaclass=NiftyMeta):
     @property
     def position(self):
         """
-        Field : selected location in parameter space.
+        :class:`nifty8.field.Field` : selected location in parameter space.
 
         The Field location in parameter space where value, gradient and
         metric are evaluated.
@@ -83,7 +83,7 @@ class Energy(metaclass=NiftyMeta):
     @property
     def gradient(self):
         """
-        Field : The gradient at given `position`.
+        :class:`nifty8.field.Field` : The gradient at given `position`.
         """
         raise NotImplementedError
 
@@ -109,12 +109,12 @@ class Energy(metaclass=NiftyMeta):
         """
         Parameters
         ----------
-        x: Field or MultiField
+        x : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             Argument for the metric operator
 
         Returns
         -------
-        Field or MultiField:
+        :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             Output of the metric operator
         """
         raise NotImplementedError
@@ -124,7 +124,7 @@ class Energy(metaclass=NiftyMeta):
 
         Parameters
         ----------
-        direction : Field
+        direction : :class:`nifty8.field.Field`
             the search direction
 
         Returns
diff --git a/src/minimization/energy_adapter.py b/src/minimization/energy_adapter.py
index 1f265c3079c8fe515c7cc42de2270a43549cb20b..afb2f15add1ce314ec9ccbb75cf7a0625fad1a46 100644
--- a/src/minimization/energy_adapter.py
+++ b/src/minimization/energy_adapter.py
@@ -33,16 +33,16 @@ class EnergyAdapter(Energy):
 
     Parameters
     -----------
-    position: Field or MultiField
+    position : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The position where the minimization process is started.
-    op: EnergyOperator
+    op : EnergyOperator
         The expression computing the energy from the input data.
-    constants: list of strings
+    constants : list of strings
         The component names of the operator's input domain which are assumed
         to be constant during the minimization process.
         If the operator's input domain is not a MultiField, this must be empty.
         Default: [].
-    want_metric: bool
+    want_metric : bool
         If True, the class will provide a `metric` property. This should only
         be enabled if it is required, because it will most likely consume
         additional resources. Default: False.
@@ -170,7 +170,7 @@ class StochasticEnergyAdapter(Energy):
 
         Parameters
         ----------
-        position : MultiField
+        position : :class:`nifty8.multi_field.MultiField`
             Values of the optimization parameters
         op : Operator
             The objective function of the optimization problem. Must have a
diff --git a/src/minimization/kl_energies.py b/src/minimization/kl_energies.py
index 4889dba9aebfe71eaa23d736c968e1c7f44a8401..06e4402731cf5bc978dac4f3e80ca3ea436b5cb0 100644
--- a/src/minimization/kl_energies.py
+++ b/src/minimization/kl_energies.py
@@ -51,7 +51,7 @@ def _reduce_by_keys(field, operator, keys):
 
     Parameters
     ----------
-    field : Field or MultiField
+    field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         Potentially partially constant input field.
     operator : Operator
         Operator into which `field` is partially inserted.
@@ -183,9 +183,9 @@ def SampledKLEnergy(position, hamiltonian, n_samples, minimizer_sampling,
 
     Parameters
     ----------
-    position : Field
+    position : :class:`nifty8.field.Field`
         Expansion point of the coordinate transformation.
-    hamiltonian : StandardHamiltonian
+    hamiltonian : :class:`nifty8.operators.energy_operators.StandardHamiltonian`
         Hamiltonian of the approximated probability distribution.
     n_samples : integer
         Number of samples used to stochastically estimate the KL.
diff --git a/src/minimization/line_search.py b/src/minimization/line_search.py
index 127df9f7d7720e31a2ff0de69c756b5c06eaa5c2..f30f753ea4fb1b1c17019d1c7b04271140d9e0d6 100644
--- a/src/minimization/line_search.py
+++ b/src/minimization/line_search.py
@@ -34,7 +34,7 @@ class LineEnergy:
         self.energy.position = zero_point + line_position*line_direction
     energy : Energy
         The Energy object which will be evaluated along the given direction.
-    line_direction : Field
+    line_direction : :class:`nifty8.field.Field`
         Direction used for line evaluation. Does not have to be normalized.
     offset :  float *optional*
         Indirectly defines the zero point of the line via the equation
@@ -156,7 +156,7 @@ class LineSearch(metaclass=NiftyMeta):
         energy : Energy
             Energy object from which we will calculate the energy and the
             gradient at a specific point.
-        pk : Field
+        pk : :class:`nifty8.field.Field`
             Vector pointing into the search direction.
         f_k_minus_1 : float, optional
             Value of the fuction (which is being minimized) at the k-1
diff --git a/src/minimization/optimize_kl.py b/src/minimization/optimize_kl.py
index 0947c3516041ab0d1a92a2651330b78ef5983834..4a9446637673c3cc230a55c900301ec16943b019 100644
--- a/src/minimization/optimize_kl.py
+++ b/src/minimization/optimize_kl.py
@@ -138,14 +138,14 @@ def optimize_kl(likelihood_energy,
     output_directory : str or None
         Directory in which all output files are saved. If None, no output is
         stored.  Default: "nifty_optimize_kl_output".
-    initial_position : Field, MultiField or None
+    initial_position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` or None
         Position in the definition space of `likelihood_energy` from which the
         optimization is started. If `None`, it starts at a random, normal
         distributed position with standard deviation 0.1. Default: None.
     initial_index : int
         Initial index that is used to enumerate the output files. May be used
         if `optimize_kl` is called multiple times. Default: 0.
-    ground_truth_position : Field, MultiField or None
+    ground_truth_position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` or None
         Position in latent space that represents the ground truth. Used only in
         plotting. May be useful for validating algorithms.
     comm : MPI communicator or None
@@ -195,7 +195,7 @@ def optimize_kl(likelihood_energy,
     -------
     sl : SampleList
 
-    mean : Field or MultiField (optional)
+    mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` (optional)
 
     Note
     ----
diff --git a/src/minimization/quadratic_energy.py b/src/minimization/quadratic_energy.py
index f921caa30f35a258275f9ecc161c186a08440c14..a63d7ee22960a0e4cc9619d2f9db8c09b281cdb8 100644
--- a/src/minimization/quadratic_energy.py
+++ b/src/minimization/quadratic_energy.py
@@ -50,9 +50,9 @@ class QuadraticEnergy(Energy):
 
         Parameters
         ----------
-        position : Field
+        position : :class:`nifty8.field.Field`
             Location in parameter space for the new Energy object.
-        grad : Field
+        grad : :class:`nifty8.field.Field`
             Energy gradient at the new position.
 
         Returns
diff --git a/src/minimization/sample_list.py b/src/minimization/sample_list.py
index 0167586a040d1684cdce3e236c4e30667c5423c9..b709c77a97500d6d5d148bebf68ad8821f7ab23e 100644
--- a/src/minimization/sample_list.py
+++ b/src/minimization/sample_list.py
@@ -453,9 +453,9 @@ class ResidualSampleList(SampleListBase):
 
         Parameters
         ----------
-        mean : Field or MultiField
+        mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             Mean of the sample list.
-        residuals : list of Field or list of MultiField
+        residuals : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField`
             List of residuals from the mean. If it is a list of `MultiField`,
             the domain of the residuals can be a subdomain of the domain of
             mean. This results in adding just a zero in respective `MultiField`
@@ -547,7 +547,7 @@ class SampleList(SampleListBase):
 
         Parameters
         ----------
-        samples : list of Field or list of MultiField
+        samples : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField`
             List of samples.
         comm : MPI communicator or None
             If not `None`, samples can be gathered across multiple MPI tasks. If
diff --git a/src/multi_field.py b/src/multi_field.py
index ed6f05d9da973d1ebc318a8eb89e51aad9f801b0..e87573901efbda77d9c0594d96cab8e1b849781f 100644
--- a/src/multi_field.py
+++ b/src/multi_field.py
@@ -262,7 +262,7 @@ class MultiField(Operator):
 
         Parameters
         ----------
-        other : MultiField
+        other : :class:`nifty8.multi_field.MultiField`
             the partner Field
 
         Returns
@@ -281,7 +281,7 @@ class MultiField(Operator):
 
         Parameters
         ----------
-        fields : iterable of MultiFields
+        fields : iterable of :class:`nifty8.multi_field.MultiField`
             The set of input fields. Their domains need not be identical.
         domain : MultiDomain or None
             If supplied, this will be the domain of the resulting field.
@@ -308,7 +308,7 @@ class MultiField(Operator):
 
         Parameters
         ----------
-        other : MultiField
+        other : :class:`nifty8.multi_field.MultiField`
             the partner Field
         neg : bool or dict
             if True, the partner field is subtracted, otherwise added
diff --git a/src/nifty2jax.py b/src/nifty2jax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bb3910370b7bf82211c567973b4c3acc9fc1131
--- /dev/null
+++ b/src/nifty2jax.py
@@ -0,0 +1,149 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial, reduce
+import operator
+from typing import Any, Callable, Optional, Union
+from warnings import warn
+
+from jax.tree_util import register_pytree_node_class
+
+from . import re as jft
+from .domain_tuple import DomainTuple
+from .field import Field
+from .multi_domain import MultiDomain
+from .multi_field import MultiField
+from .operators.operator import Operator
+from .sugar import makeField
+
+
+def spaces_to_axes(domain, spaces):
+    """Converts spaces in a domain to axes of the underlying NumPy array."""
+    if spaces is None:
+        return None
+
+    domain = DomainTuple.make(domain)
+    axes = tuple(domain.axes[sp_index] for sp_index in spaces)
+    axes = reduce(operator.add, axes) if len(axes) > 0 else axes
+    return axes
+
+
+def shapewithdtype_from_domain(domain, dtype):
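+    """Maps a NIFTy domain to (a tree of) `jft.ShapeWithDtype` objects.
+
+    If `dtype` is a `dict`, it assigns a dtype per `MultiDomain` key; keys
+    without an entry fall back to `float`.
+    """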
+    if isinstance(dtype, dict):
+        dtp_fallback = float  # Fallback to `float` for unspecified keys
+        k2dtp = dtype
+    else:
+        dtp_fallback = dtype
+        k2dtp = {}
+
+    if isinstance(domain, MultiDomain):
+        parameter_tree = {}
+        for k, dom in domain.items():
+            parameter_tree[k] = jft.ShapeWithDtype(
+                dom.shape, k2dtp.get(k, dtp_fallback)
+            )
+    elif isinstance(domain, DomainTuple):
+        parameter_tree = jft.ShapeWithDtype(domain.shape, dtype)
+    else:
+        raise TypeError(f"incompatible domain {domain!r}")
+    return parameter_tree
+
+
+@register_pytree_node_class
+class Model(jft.Field):
+    """Modified field class with an additional call method taking itself as
+    input.
+    """
+    def __init__(self, apply: Optional[Callable], val, domain=None, flags=None):
+        """Instantiates a modified field with an accompanying callable.
+
+        Parameters
+        ----------
+        apply : callable
+            Method acting on `val`.
+        val : object
+            Arbitrary, flatten-able objects.
+        domain : dict or None, optional
+            Domain of the field, e.g. with description of modes and volume.
+        flags : set, str or None, optional
+            Capabilities and constraints of the field.
+        """
+        super().__init__(val, domain, flags)
+        self._apply = apply
+
+    def tree_flatten(self):
+        return ((self._val, ), (self._apply, self._domain, self._flags))
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(
+            aux_data[0], *children, domain=aux_data[1], flags=aux_data[2]
+        )
+
+    def __call__(self, *args, **kwargs):
+        if self._apply is None:
+            nie = "no `apply` method specified; behaving like field"
+            raise NotImplementedError(nie)
+        return self._apply(*args, **kwargs)
+
+    @property
+    def field(self):
+        return jft.Field(self.val, domain=self.domain, flags=self.flags)
+
+    def __str__(self):
+        s = f"Model(\n{self._apply},\n{self.val}"
+        if self._domain:
+            s += f",\ndomain={self._domain}"
+        if self._flags:
+            s += f",\nflags={self._flags}"
+        s += ")"
+        return s
+
+    def __repr__(self):
+        s = f"Model(\n{self._apply!r},\n{self.val!r}"
+        if self._domain:
+            s += f",\ndomain={self._domain!r}"
+        if self._flags:
+            s += f",\nflags={self._flags!r}"
+        s += ")"
+        return s
+
+
+def wrap_nifty_call(op, target_dtype=float) -> Callable[[Any], jft.Field]:
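+    """Wraps a NIFTy operator in a `host_callback.call` such that it can be
+    evaluated from within JAX. Note that no custom JVP/VJP rules are
+    registered yet (see the TODO below), so the wrapped call is not
+    differentiable.
+    """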
+    from jax.experimental.host_callback import call
+
+    if callable(op.jax_expr):
+        warn("wrapping operator that has a callable `.jax_expr`")
+
+    def pack_unpack_call(x):
+        x = makeField(op.domain, x)
+        return op(x).val
+
+    # TODO: define custom JVP and VJP rules
+    pt = shapewithdtype_from_domain(op.target, target_dtype)
+    hcb_call = partial(call, pack_unpack_call, result_shape=pt)
+
+    def wrapped_call(x) -> jft.Field:
+        return jft.Field(hcb_call(x))
+
+    return wrapped_call
+
+
+def convert(nifty_obj: Union[Operator, DomainTuple, MultiDomain], dtype=float) -> Model:
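+    """Converts a NIFTy operator, field or domain to a jifty :class:`Model`.
+
+    The returned model wraps the operator's `jax_expr` (if any) together with
+    a parameter tree mirroring the input's domain (or, for fields, its
+    values).
+    """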
+    if not isinstance(nifty_obj, (Operator, DomainTuple, MultiDomain)):
+        raise TypeError(f"invalid input type {type(nifty_obj)!r}")
+
+    if isinstance(nifty_obj, (Field, MultiField)):
+        expr = None
+        parameter_tree = jft.Field(nifty_obj.val)
+    elif isinstance(nifty_obj, (DomainTuple, MultiDomain)):
+        expr = None
+        parameter_tree = shapewithdtype_from_domain(nifty_obj, dtype)
+    else:
+        expr = nifty_obj.jax_expr
+        parameter_tree = shapewithdtype_from_domain(nifty_obj.domain, dtype)
+        if not callable(expr):
+            # TODO: implement conversion via host_callback and custom_vjp
+            raise NotImplementedError("Sorry, not yet done :(")
+
+    return Model(expr, parameter_tree)
diff --git a/src/operators/adder.py b/src/operators/adder.py
index 6e9f0aece191ca9ead3454879a160476e97883fb..58bee1c735b8e01ee06ee0acfcace9cdc2e055e0 100644
--- a/src/operators/adder.py
+++ b/src/operators/adder.py
@@ -15,6 +15,8 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from operator import add, sub
+
 import numpy as np
 
 from ..field import Field
@@ -28,7 +30,7 @@ class Adder(Operator):
 
     Parameters
     ----------
-    a : Field or MultiField or Scalar
+    a : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Scalar
         The field by which the input is shifted.
     """
     def __init__(self, a, neg=False, domain=None):
@@ -42,6 +44,24 @@ class Adder(Operator):
         self._domain = self._target = dom
         self._neg = bool(neg)
 
+        try:
+            from ..re import Field as ReField
+            from jax.tree_util import tree_map
+
+            a_j = ReField(a.val) if isinstance(a, (Field, MultiField)) else a
+
+            def jax_expr(x):
+                # Preserve the input type
+                if not isinstance(x, ReField):
+                    a_astype_x = a_j.val if isinstance(a_j, ReField) else a_j
+                else:
+                    a_astype_x = a_j
+                return tree_map(sub if neg else add, x, a_astype_x)
+
+            self._jax_expr = jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x):
         self._check_input(x)
         if self._neg:
diff --git a/src/operators/chain_operator.py b/src/operators/chain_operator.py
index 62e7b8bf6eea18277eb687224f2f701c0a5db8b0..52f06fe2ef86b5814c84f35a776ea92317a0fbcf 100644
--- a/src/operators/chain_operator.py
+++ b/src/operators/chain_operator.py
@@ -40,6 +40,17 @@ class ChainOperator(LinearOperator):
         self._domain = self._ops[-1].domain
         self._target = self._ops[0].target
 
+        if all(callable(op.jax_expr) for op in ops):
+
+            def joined_jax_op(x):
+                for op in reversed(ops):
+                    x = op.jax_expr(x)
+                return x
+
+            self._jax_expr = joined_jax_op
+        else:
+            self._jax_expr = None
+
     @staticmethod
     def simplify(ops):
         # verify domains
diff --git a/src/operators/contraction_operator.py b/src/operators/contraction_operator.py
index 9eb10770752bf32d01444a7be403a5beb8a358c6..2db2343597a6125176d3d3c067ff1f4b92f50f87 100644
--- a/src/operators/contraction_operator.py
+++ b/src/operators/contraction_operator.py
@@ -15,6 +15,8 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from functools import partial
+
 import numpy as np
 
 from .. import utilities
@@ -51,6 +53,35 @@ class ContractionOperator(LinearOperator):
         self._power = power
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from jax import numpy as jnp
+            from jax.tree_util import tree_map
+            from ..nifty2jax import spaces_to_axes
+
+            fct = jnp.array(1.)
+            wgt = jnp.array(1.)
+            if self._power != 0:
+                for ind in self._spaces:
+                    wgt_spc = self._domain[ind].dvol
+                    if np.isscalar(wgt_spc):
+                        fct *= wgt_spc
+                    else:
+                        new_shape = np.ones(len(self._domain.shape), dtype=np.int64)
+                        new_shape[self._domain.axes[ind][0]:
+                                  self._domain.axes[ind][-1]+1] = wgt_spc.shape
+                        wgt *= wgt_spc.reshape(new_shape)**power
+                fct = fct**power
+
+            def weighted_space_sum(x):
+                if self._power != 0:
+                    x = fct * wgt * x
+                axes = spaces_to_axes(self._domain, self._spaces)
+                return tree_map(partial(jnp.sum, axis=axes), x)
+
+            self._jax_expr = weighted_space_sum
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         if mode == self.ADJOINT_TIMES:
diff --git a/src/operators/diagonal_operator.py b/src/operators/diagonal_operator.py
index caeefca05c44a746f72a1d22ea23a298f71d499a..22978268397b004a984f96263f5d0cf3b8d102a8 100644
--- a/src/operators/diagonal_operator.py
+++ b/src/operators/diagonal_operator.py
@@ -16,6 +16,8 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 import numpy as np
+from functools import partial
+from operator import mul
 
 from .. import utilities
 from ..domain_tuple import DomainTuple
@@ -32,7 +34,7 @@ class DiagonalOperator(EndomorphicOperator):
 
     Parameters
     ----------
-    diagonal : Field
+    diagonal : :class:`nifty8.field.Field`
         The diagonal entries of the operator.
     domain : Domain, tuple of Domain or DomainTuple, optional
         The domain on which the Operator's input Field is defined.
@@ -92,6 +94,8 @@ class DiagonalOperator(EndomorphicOperator):
             self._ldiag = diagonal.val
         self._fill_rest()
 
+        self._jax_expr = partial(mul, self._ldiag)
+
     def _fill_rest(self):
         self._ldiag.flags.writeable = False
         self._complex = utilities.iscomplextype(self._ldiag.dtype)
@@ -109,6 +113,9 @@ class DiagonalOperator(EndomorphicOperator):
             res._spaces = tuple(set(self._spaces) | set(spc))
         res._ldiag = np.array(ldiag)
         res._fill_rest()
+
+        res._jax_expr = partial(mul, res._ldiag)
+
         return res
 
     def _scale(self, fct):
diff --git a/src/operators/einsum.py b/src/operators/einsum.py
index 03aa39989fb47ca3175491333690eba594b280be..8486a186635ae8c697635aa6f8ebd28e4c8808b8 100644
--- a/src/operators/einsum.py
+++ b/src/operators/einsum.py
@@ -174,7 +174,7 @@ class LinearEinsum(LinearOperator):
     ----------
     domain : Domain, DomainTuple or tuple of Domain
         The operator's input domain.
-    mf : MultiField
+    mf : :class:`nifty8.multi_field.MultiField`
         The first part of the left-hand side of the einsum.
     subscripts : str
         The subscripts which is passed to einsum. Everything before the very
diff --git a/src/operators/endomorphic_operator.py b/src/operators/endomorphic_operator.py
index 6adbfd1bf413b235a817378089731f5519e0c553..0e3fad962e5e540078c1cb58df63864043f7c6f7 100644
--- a/src/operators/endomorphic_operator.py
+++ b/src/operators/endomorphic_operator.py
@@ -43,7 +43,7 @@ class EndomorphicOperator(LinearOperator):
 
         Returns
         -------
-        Field or MultiField
+        :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             A sample from the Gaussian of given covariance.
         """
         raise NotImplementedError
diff --git a/src/operators/energy_operators.py b/src/operators/energy_operators.py
index 4d2b3b8c78c77743c37d3362bf6369009eaa99c2..6d9f11c15300955c5be0bf6438a3fb0075a14132 100644
--- a/src/operators/energy_operators.py
+++ b/src/operators/energy_operators.py
@@ -488,7 +488,7 @@ class GaussianEnergy(LikelihoodEnergyOperator):
 
     Parameters
     ----------
-    data : Field or None
+    data : :class:`nifty8.field.Field` or None
         Observed data of the Gaussian likelihood. If `inverse_covariance` is
         `None`, the `dtype` of `data` is used for sampling. Default is
         0.
@@ -597,7 +597,7 @@ class PoissonianEnergy(LikelihoodEnergyOperator):
 
     Parameters
     ----------
-    d : Field
+    d : :class:`nifty8.field.Field`
         Data field with counts. Needs to have integer dtype and all field
         values need to be non-negative.
     """
@@ -635,14 +635,14 @@ class InverseGammaEnergy(LikelihoodEnergyOperator):
         \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i
 
     This is the likelihood for the variance :math:`x=S_k` given data
-    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
-    the covariance :math:`S_k`.
+    :math:`\\beta = 0.5 |s_k|^2` where the :class:`nifty8.field.Field`
+    :math:`s` is known to have the covariance :math:`S_k`.
 
     Parameters
     ----------
-    beta : Field
+    beta : :class:`nifty8.field.Field`
         beta parameter of the inverse gamma distribution
-    alpha : Scalar, Field, optional
+    alpha : Scalar, :class:`nifty8.field.Field`, optional
         alpha parameter of the inverse gamma distribution
     """
 
@@ -694,7 +694,7 @@ class StudentTEnergy(LikelihoodEnergyOperator):
     ----------
     domain : `Domain` or `DomainTuple`
         Domain of the operator
-    theta : Scalar or Field
+    theta : Scalar or :class:`nifty8.field.Field`
         Degree of freedom parameter for the student t distribution
     """
 
@@ -733,7 +733,7 @@ class BernoulliEnergy(LikelihoodEnergyOperator):
 
     Parameters
     ----------
-    d : Field
+    d : :class:`nifty8.field.Field`
         Data field with events (1) or non-events (0).
     """
 
@@ -838,7 +838,7 @@ class AveragedEnergy(EnergyOperator):
     ----------
     h: Hamiltonian
        The energy to be averaged.
-    res_samples : iterable of Fields
+    res_samples : iterable of :class:`nifty8.field.Field`
        Set of residual sample points to be added to mean field for
        approximate estimation of the KL.
 
diff --git a/src/operators/harmonic_operators.py b/src/operators/harmonic_operators.py
index 39d4b31bead9cdfa6b202665fbee32a29209cd76..1895f8cdf6cc599923cf2286a829eec38892bf7f 100644
--- a/src/operators/harmonic_operators.py
+++ b/src/operators/harmonic_operators.py
@@ -16,6 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 import numpy as np
+from functools import partial
 
 from .. import utilities
 from ..domain_tuple import DomainTuple
@@ -71,6 +72,36 @@ class FFTOperator(LinearOperator):
         adom.check_codomain(target)
         target.check_codomain(adom)
 
+        try:
+            from jax.numpy import fft as jfft
+
+            axes = self.domain.axes[self._space]
+
+            def jax_expr(x, inverse=False):
+                if inverse:
+                    if self.domain[self._space].harmonic:
+                        func = jfft.fftn
+                        fct = 1.
+                    else:
+                        func = jfft.ifftn
+                        fct = self.domain[self._space].size
+                    fct *= self.target[self._space].scalar_dvol
+                else:
+                    if self.domain[self._space].harmonic:
+                        func = jfft.ifftn
+                        fct = self.domain[self._space].size
+                    else:
+                        func = jfft.fftn
+                        fct = 1.
+                    fct *= self.domain[self._space].scalar_dvol
+                return fct * func(x, axes=axes) if fct != 1 else func(x, axes=axes)
+
+            self._jax_expr = jax_expr
+            self._jax_expr_inv = partial(jax_expr, inverse=True)
+
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         ncells = x.domain[self._space].size
@@ -138,6 +169,33 @@ class HartleyOperator(LinearOperator):
         adom.check_codomain(target)
         target.check_codomain(adom)
 
+        try:
+            from jax.numpy import fft as jfft
+
+            axes = self.domain.axes[self._space]
+
+            def hartley(a):
+                ft = jfft.fftn(a, axes=axes)
+                return ft.real + ft.imag
+
+            def apply_cartesian(x, inverse=False):
+                if inverse:
+                    fct = self.target[self._space].scalar_dvol
+                else:
+                    fct = self.domain[self._space].scalar_dvol
+                return fct * hartley(x) if fct != 1 else hartley(x)
+
+            def jax_expr(x, inverse=False):
+                ap = partial(apply_cartesian, inverse=inverse)
+                if np.issubdtype(x.dtype.type, np.complexfloating):
+                    return ap(x.real) + 1j * ap(x.imag)
+                return ap(x)
+
+            self._jax_expr = jax_expr
+            self._jax_expr_inv = partial(jax_expr, inverse=True)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         if utilities.iscomplextype(x.dtype):
@@ -314,6 +372,7 @@ class HarmonicTransformOperator(LinearOperator):
         self._domain = self._op.domain
         self._target = self._op.target
         self._capability = self.TIMES | self.ADJOINT_TIMES
+        self._jax_expr = self._op.jax_expr
 
     def apply(self, x, mode):
         self._check_input(x, mode)
diff --git a/src/operators/jax_operator.py b/src/operators/jax_operator.py
index a70676cfbe98ff349ca1d2fff4502497d769110d..79275a71fdf638ab0565f76bfe034cba1913a5a6 100644
--- a/src/operators/jax_operator.py
+++ b/src/operators/jax_operator.py
@@ -18,6 +18,7 @@ from types import SimpleNamespace
 from warnings import warn
 
 import numpy as np
+from functools import partial
 
 from .energy_operators import LikelihoodEnergyOperator
 from .linear_operator import LinearOperator
@@ -59,17 +60,23 @@ class JaxOperator(Operator):
         self._domain = makeDomain(domain)
         self._target = makeDomain(target)
         self._func = jax.jit(func)
-        self._vjp = jax.jit(lambda x: jax.vjp(func, x))
+        self._bwd = jax.jit(lambda x, y: jax.vjp(func, x)[1](y)[0])
         self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1])
 
+        self._jax_expr = func
+
     def apply(self, x):
         from ..multi_domain import MultiDomain
         from ..sugar import is_linearization, makeField
         self._check_input(x)
         if is_linearization(x):
-            res, bwd = self._vjp(x.val.val)
-            fwd = lambda y: self._fwd(x.val.val, y)
-            jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=lambda x: bwd(x)[0])
+            # TODO: Adapt the Linearization class to handle value_and_grad
+            # calls. Computing the pass through the function thrice (once now
+            # and twice when differentiating) is redundant and inefficient.
+            res = self._func(x.val.val)
+            bwd = partial(self._bwd, x.val.val)
+            fwd = partial(self._fwd, x.val.val)
+            jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
             return x.new(makeField(self._target, _jax2np(res)), jac)
         res = _jax2np(self._func(x.val))
         if isinstance(res, dict):
@@ -157,6 +164,8 @@ class JaxLinearOperator(LinearOperator):
         self._func_T = func_T
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        self._jax_expr = func
+
     def apply(self, x, mode):
         from ..sugar import makeField
         self._check_input(x, mode)
diff --git a/src/operators/linear_operator.py b/src/operators/linear_operator.py
index c2186c8875195e824382324854057e4df5b050bd..7c6522441dd69f3cbd0a4d57d36a72adf8596d28 100644
--- a/src/operators/linear_operator.py
+++ b/src/operators/linear_operator.py
@@ -148,7 +148,7 @@ class LinearOperator(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             The input Field, defined on the Operator's domain or target,
             depending on mode.
 
@@ -161,7 +161,7 @@ class LinearOperator(Operator):
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             The processed Field defined on the Operator's target or domain,
             depending on mode.
         """
@@ -180,12 +180,12 @@ class LinearOperator(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             The input Field, defined on the Operator's domain.
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             The processed Field defined on the Operator's target domain.
         """
         return self.apply(x, self.TIMES)
@@ -195,12 +195,12 @@ class LinearOperator(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             The input Field, defined on the Operator's target domain
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             The processed Field defined on the Operator's domain.
         """
         return self.apply(x, self.INVERSE_TIMES)
@@ -210,12 +210,12 @@ class LinearOperator(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             The input Field, defined on the Operator's target domain
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             The processed Field defined on the Operator's domain.
         """
         return self.apply(x, self.ADJOINT_TIMES)
@@ -225,12 +225,12 @@ class LinearOperator(Operator):
 
         Parameters
         ----------
-        x : Field
+        x : :class:`nifty8.field.Field`
             The input Field, defined on the Operator's domain.
 
         Returns
         -------
-        Field
+        :class:`nifty8.field.Field`
             The processed Field defined on the Operator's target domain.
 
         Notes
diff --git a/src/operators/mask_operator.py b/src/operators/mask_operator.py
index 11adb3b4efbc60de385331b970bcad8f87faa999..9d2584ffafc22119cf99be30561112a99b840695 100644
--- a/src/operators/mask_operator.py
+++ b/src/operators/mask_operator.py
@@ -31,7 +31,7 @@ class MaskOperator(LinearOperator):
 
     Parameters
     ----------
-    flags : Field
+    flags : :class:`nifty8.field.Field`
         Is converted to boolean. Where True, the input field is flagged.
     """
     def __init__(self, flags):
@@ -42,6 +42,11 @@ class MaskOperator(LinearOperator):
         self._target = DomainTuple.make(UnstructuredDomain(self._flags.sum()))
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        def mask(x):
+            return x[self._flags]
+
+        self._jax_expr = mask
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         x = x.val
diff --git a/src/operators/operator.py b/src/operators/operator.py
index 60e7fc5a0c706243bc05df54dd4025d0179469f0..722c1946b70bd04e7aa436f8fa7edf67502c259b 100644
--- a/src/operators/operator.py
+++ b/src/operators/operator.py
@@ -21,6 +21,9 @@ from operator import add
 
 import numpy as np
 
+from warnings import warn
+from typing import Callable, Optional
+
 from .. import pointwise
 from ..domain_tuple import DomainTuple
 from ..logger import logger
@@ -112,6 +115,15 @@ class Operator(metaclass=NiftyMeta):
         """
         return None
 
+    @property
+    def jax_expr(self) -> Optional[Callable]:
+        """Equivalent representation of the operator in JAX."""
+        expr = getattr(self, "_jax_expr", None)
+        # NOTE, it is incredibly useful to enable this for debugging
+        # if expr is None:
+        #     warn(f"no JAX expression associated with operator {self!r}")
+        return expr
+
     def scale(self, factor):
         if not isinstance(factor, numbers.Number):
             raise TypeError(".scale() takes a number as input")
@@ -250,7 +262,7 @@ class Operator(metaclass=NiftyMeta):
 
         Parameters
         ----------
-        x : Field or MultiField
+        x : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
             Input on which the operator shall act. Needs to be defined on
             :attr:`domain`.
         """
@@ -416,6 +428,32 @@ class _FunctionApplier(Operator):
         self._args = args
         self._kwargs = kwargs
 
+        try:
+            import jax.numpy as jnp
+            from jax import nn as jax_nn
+
+            if funcname in pointwise.ptw_nifty2jax_dict:
+                jax_expr = pointwise.ptw_nifty2jax_dict[funcname]
+            elif hasattr(jnp, funcname):
+                jax_expr = getattr(jnp, funcname)
+            elif hasattr(jax_nn, funcname):
+                jax_expr = getattr(jax_nn, funcname)
+            else:
+                warn(f"unable to add JAX call for {funcname!r}")
+                jax_expr = None
+
+            def jax_expr_part(x):  # partially apply `args`/`kwargs`; `x` stays open
+                return jax_expr(x, *args, **kwargs)
+
+            if isinstance(self.domain, MultiDomain):
+                from functools import partial
+                from jax.tree_util import tree_map
+
+                jax_expr_part = partial(tree_map, jax_expr_part)
+            self._jax_expr = jax_expr_part
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x):
         self._check_input(x)
         return x.ptw(self._funcname, *self._args, **self._kwargs)
@@ -425,11 +463,22 @@ class _FunctionApplier(Operator):
 
 
 class _CombinedOperator(Operator):
-    def __init__(self, ops, _callingfrommake=False):
+    def __init__(self, ops, jax_ops, _callingfrommake=False):
         if not _callingfrommake:
             raise NotImplementedError
         self._ops = tuple(ops)
 
+        if all(callable(jop) for jop in jax_ops):
+
+            def joined_jax_op(x):
+                for jop in reversed(jax_ops):
+                    x = jop(x)
+                return x
+
+            self._jax_expr = joined_jax_op
+        else:
+            self._jax_expr = None
+
     @classmethod
     def unpack(cls, ops, res):
         for op in ops:
@@ -444,12 +493,13 @@ class _CombinedOperator(Operator):
         res = cls.unpack(ops, [])
         if len(res) == 1:
             return res[0]
-        return cls(res, _callingfrommake=True)
+        jax_res = tuple(op.jax_expr for op in ops)
+        return cls(res, jax_res, _callingfrommake=True)
 
 
 class _OpChain(_CombinedOperator):
-    def __init__(self, ops, _callingfrommake=False):
-        super(_OpChain, self).__init__(ops, _callingfrommake)
+    def __init__(self, ops, jax_ops, _callingfrommake=False):
+        super(_OpChain, self).__init__(ops, jax_ops, _callingfrommake)
         self._domain = self._ops[-1].domain
         self._target = self._ops[0].target
         for i in range(1, len(self._ops)):
@@ -486,6 +536,17 @@ class _OpProd(Operator):
         self._op1 = op1
         self._op2 = op2
 
+        lhs_has_jax = callable(self._op1.jax_expr)
+        rhs_has_jax = callable(self._op2.jax_expr)
+        if lhs_has_jax and rhs_has_jax:
+
+            def joined_jax_expr(x):
+                return self._op1.jax_expr(x) * self._op2.jax_expr(x)
+
+            self._jax_expr = joined_jax_expr
+        else:
+            self._jax_expr = None
+
     def apply(self, x):
         from ..linearization import Linearization
         from ..sugar import makeOp
@@ -529,6 +590,16 @@ class _OpSum(Operator):
         self._op1 = op1
         self._op2 = op2
 
+        try:
+            from ..re import unite
+
+            def joined_jax_expr(x):
+                return unite(self._op1.jax_expr(x), self._op2.jax_expr(x))
+
+            self._jax_expr = joined_jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x):
         self._check_input(x)
         return self._apply_operator_sum(x, [self._op1, self._op2])
diff --git a/src/operators/operator_adapter.py b/src/operators/operator_adapter.py
index e16a43e3565b0d073ccd1afc9c75b24bb7c02f19..9f8da261e0afa8c0aa29fda0b8481015e95cddbb 100644
--- a/src/operators/operator_adapter.py
+++ b/src/operators/operator_adapter.py
@@ -38,7 +38,7 @@ class OperatorAdapter(LinearOperator):
         3) adjoint inverse
     """
 
-    def __init__(self, op, op_transform):
+    def __init__(self, op, op_transform, domain_dtype=float):
         self._op = op
         self._trafo = int(op_transform)
         if self._trafo < 1 or self._trafo > 3:
@@ -47,6 +47,35 @@ class OperatorAdapter(LinearOperator):
         self._target = self._op._tgt(1 << self._trafo)
         self._capability = self._capTable[self._trafo][self._op.capability]
 
+        try:
+            from jax import eval_shape, linear_transpose
+            import jax.numpy as jnp
+            from jax.tree_util import tree_map, tree_all
+
+            from ..nifty2jax import shapewithdtype_from_domain
+            from ..re import Field
+
+            if callable(op.jax_expr) and self._trafo == self.ADJOINT_BIT:
+                def jax_expr(y):
+                    op_domain = shapewithdtype_from_domain(op.domain, domain_dtype)
+                    op_domain = Field(op_domain) if isinstance(y, Field) else op_domain
+                    tentative_yshape = eval_shape(op.jax_expr, op_domain)
+                    if not tree_all(tree_map(lambda a, b: jnp.can_cast(a.dtype, b.dtype), y, tentative_yshape)):
+                        ve = ("wrong dtype during transposition;"
+                              f" got {y!r} but expected {tentative_yshape}")
+                        raise ValueError(ve)
+                    y = tree_map(lambda c, d: c.astype(d.dtype, casting="safe", copy=False), y, tentative_yshape)
+                    y_conj = tree_map(jnp.conj, y)
+                    jax_expr_T = linear_transpose(op.jax_expr, op_domain)
+                    return tree_map(jnp.conj, jax_expr_T(y_conj)[0])
+
+                self._jax_expr = jax_expr
+            elif hasattr(op, "_jax_expr_inv") and callable(op._jax_expr_inv) and self._trafo == self.INVERSE_BIT:
+                self._jax_expr = op._jax_expr_inv
+                self._jax_expr_inv = op._jax_expr
+            else:
+                self._jax_expr = None
+        except ImportError:
+            self._jax_expr = None
+
     def _flip_modes(self, trafo):
         newtrafo = trafo ^ self._trafo
         return self._op if newtrafo == 0 \
diff --git a/src/operators/outer_product_operator.py b/src/operators/outer_product_operator.py
index 72b8b71f233784e2087de6e1149669d610dd2e91..6dd6a4d4248b2ee9fb8ded42488a120700b4d6aa 100644
--- a/src/operators/outer_product_operator.py
+++ b/src/operators/outer_product_operator.py
@@ -15,10 +15,12 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from functools import partial
 import numpy as np
 
 from ..domain_tuple import DomainTuple
 from ..field import Field
+from ..multi_field import MultiField
 from .linear_operator import LinearOperator
 
 
@@ -27,8 +29,8 @@ class OuterProduct(LinearOperator):
 
     Parameters
     ---------
-    domain: DomainTuple, the domain of the input field
-    field: Field
+    domain : DomainTuple
+        The domain of the input field.
+    field : :class:`nifty8.field.Field`
+        The field of which the outer product with the input is taken.
     ---------
     """
     def __init__(self, domain, field):
@@ -38,6 +40,29 @@ class OuterProduct(LinearOperator):
             tuple(sub_d for sub_d in field.domain._dom + self._domain._dom))
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from ..re import Field as ReField
+            from jax import numpy as jnp
+            from jax.tree_util import tree_map
+
+            a_j = ReField(field.val) if isinstance(field, (Field, MultiField)) else field
+
+            def jax_expr(x):
+                # Preserve the input type
+                if not isinstance(x, ReField):
+                    a_astype_x = a_j.val if isinstance(a_j, ReField) else a_j
+                else:
+                    a_astype_x = a_j
+
+                return tree_map(
+                    partial(jnp.tensordot, axes=((), ())),
+                    a_astype_x, x
+                )
+
+            self._jax_expr = jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         if mode == self.TIMES:
diff --git a/src/operators/scaling_operator.py b/src/operators/scaling_operator.py
index ab6e79f9a2770d500b2e3e8ba5a6a27b154668b9..6a557027836fc84a38ca08d0b747eb80d20bf2f4 100644
--- a/src/operators/scaling_operator.py
+++ b/src/operators/scaling_operator.py
@@ -66,6 +66,14 @@ class ScalingOperator(EndomorphicOperator):
         check_dtype_or_none(sampling_dtype, self._domain)
         self._dtype = sampling_dtype
 
+        try:
+            from jax import numpy as jnp
+            from functools import partial
+
+            self._jax_expr = partial(jnp.multiply, factor)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         from ..sugar import full
 
diff --git a/src/operators/simple_linear_operators.py b/src/operators/simple_linear_operators.py
index 336f19b1bfb84158b9757405e86b68c0235bc2fa..5e075cb3d7809334aec3a134810934391ca62806 100644
--- a/src/operators/simple_linear_operators.py
+++ b/src/operators/simple_linear_operators.py
@@ -16,6 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 import numpy as np
+from functools import partial
 
 from ..domain_tuple import DomainTuple
 from ..domains.unstructured_domain import UnstructuredDomain
@@ -32,7 +33,7 @@ class VdotOperator(LinearOperator):
 
     Parameters
     ----------
-    field : Field or MultiField
+    field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The field used to build the scalar product with the operator input
     """
     def __init__(self, field):
@@ -41,6 +42,13 @@ class VdotOperator(LinearOperator):
         self._target = DomainTuple.scalar_domain()
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from ..re import vdot
+
+            self._jax_expr = partial(vdot, field.val)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_mode(mode)
         if mode == self.TIMES:
@@ -61,6 +69,14 @@ class ConjugationOperator(EndomorphicOperator):
         self._domain = DomainTuple.make(domain)
         self._capability = self._all_ops
 
+        try:
+            from jax import numpy as jnp
+            from jax.tree_util import tree_map
+
+            self._jax_expr = partial(tree_map, jnp.conjugate)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         return x.conjugate()
@@ -108,6 +124,14 @@ class Realizer(EndomorphicOperator):
         self._domain = DomainTuple.make(domain)
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from jax import numpy as jnp
+            from jax.tree_util import tree_map
+
+            self._jax_expr = partial(tree_map, jnp.real)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         return x.real
@@ -126,6 +150,14 @@ class Imaginizer(EndomorphicOperator):
         self._domain = DomainTuple.make(domain)
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from jax import numpy as jnp
+            from jax.tree_util import tree_map
+
+            self._jax_expr = partial(tree_map, jnp.imag)
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         if mode == self.TIMES:
@@ -166,6 +198,22 @@ class FieldAdapter(LinearOperator):
             self._target = MultiDomain.make({name: tmp[name]})
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        try:
+            from .. import re as jft
+
+            def wrap(x):
+                return jft.Field({name: x})
+
+            def unwrap(x):
+                return x[name]
+
+            if isinstance(tmp, DomainTuple):
+                self._jax_expr = unwrap
+            else:
+                self._jax_expr = wrap
+        except ImportError:
+            self._jax_expr = None
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         if isinstance(x, MultiField):
@@ -310,6 +358,11 @@ class GeometryRemover(LinearOperator):
         self._target = DomainTuple.make(tgt)
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
+        def identity(x):
+            return x
+
+        self._jax_expr = identity
+
     def apply(self, x, mode):
         self._check_input(x, mode)
         return x.cast_domain(self._tgt(mode))
diff --git a/src/operators/sum_operator.py b/src/operators/sum_operator.py
index 3cbbb05106b75112e853b924e861beedd7b3a26e..65a8018ed800269783426baa3df48a64cdb97574 100644
--- a/src/operators/sum_operator.py
+++ b/src/operators/sum_operator.py
@@ -16,6 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 from collections import defaultdict
+import operator
 
 from ..sugar import domain_union
 from ..utilities import indent
@@ -42,6 +43,24 @@ class SumOperator(LinearOperator):
         for op in ops:
             self._capability &= op.capability
 
+        try:
+            from ..re import unite
+
+            def joined_jax_expr(x):
+                res = None
+                for op, n in zip(ops, neg):
+                    tmp = op.jax_expr(x)
+                    if res is None:
+                        res = -tmp if n is True else tmp
+                    else:
+                        o = operator.sub if n is True else operator.add
+                        res = unite(res, tmp, op=o)
+                return res
+
+            self._jax_expr = joined_jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     @staticmethod
     def simplify(ops, neg):
         from .diagonal_operator import DiagonalOperator
@@ -173,7 +192,7 @@ class SumOperator(LinearOperator):
             Individual operators of the sum.
         neg: list of bool
             Same length as ops.
-            If True then the equivalent operator gets a minus in the sum.
+            If True then the corresponding operator gets a minus in the sum.
         """
         ops = tuple(ops)
         neg = tuple(neg)
diff --git a/src/plot.py b/src/plot.py
index d47fca6d6a3f5267b797770d36fa366247ad6163..efcba788a55ff146d0a810bfed50a9251b0cf523 100644
--- a/src/plot.py
+++ b/src/plot.py
@@ -575,7 +575,7 @@ class Plot:
 
         Parameters
         ----------
-        f: Field or list of Field or None
+        f : :class:`nifty8.field.Field` or list of :class:`nifty8.field.Field` or None
             If `f` is a single Field, it must be defined on a single `RGSpace`,
             `PowerSpace`, `HPSpace`, `GLSpace`.
             If it is a list, all list members must be Fields defined over the
diff --git a/src/pointwise.py b/src/pointwise.py
index b709d2ba674b5ecccae9fdbb2ebbf5598bc0749b..a15d74bf34f8269da0fab50a4a5c4c53cb234280 100644
--- a/src/pointwise.py
+++ b/src/pointwise.py
@@ -153,3 +153,23 @@ ptw_dict = {
     "arctan": (np.arctan, lambda v: (np.arctan(v), 1./(1.+v**2))),
     "unitstep": (lambda v: _step_helper(v, False), lambda v: _step_helper(v, True))
     }
+
+
+def sigmoid_j(v):
+    from jax import numpy as jnp
+
+    # NOTE, the sigmoid used in NIFTy is different from the one commonly
+    # referred to as the sigmoid in most of the literature.
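+    # Concretely, 0.5 + 0.5 * tanh(v) equals the logistic function
+    # 1 / (1 + exp(-x)) evaluated at x = 2 * v.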
+    return 0.5 + (0.5 * jnp.tanh(v))
+
+
+def exponentiate_j(v, base):
+    from jax import numpy as jnp
+
+    return jnp.power(base, v)
+
+
+ptw_nifty2jax_dict = {
+    "sigmoid": sigmoid_j,
+    "exponentiate": exponentiate_j,
+}
diff --git a/src/probing.py b/src/probing.py
index eae3ef916155708b699ba543ddb019949abcef1a..04068f0c7232c48379ff66787d7ebe007da907b0 100644
--- a/src/probing.py
+++ b/src/probing.py
@@ -87,7 +87,7 @@ def probe_with_posterior_samples(op, post_op, nprobes, dtype):
 
     Returns
     -------
-    List of Field
+    List of :class:`nifty8.field.Field`
         List of two fields: the mean and the variance.
     '''
     if not isinstance(op, EndomorphicOperator):
@@ -129,7 +129,7 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
 
     Returns
     -------
-    Field
+    :class:`nifty8.field.Field`
         The estimated diagonal.
     '''
     sc = StatCalculator()
diff --git a/src/re/README.md b/src/re/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1dfa61822c8f6496f658c84d173acac7f062b578
--- /dev/null
+++ b/src/re/README.md
@@ -0,0 +1,24 @@
+# Re-envisioning NIFTy
+
+## JAX
+
+The (soft-linked) code in this directory is a new interface for NIFTy written in JAX.
+Some features of this new API are straightforward re-implementations of features in NIFTy, while others are orthogonal to NIFTy and follow a different, usually more functional approach.
+All essential pieces of NIFTy are implemented, and the API is capable of (almost) fully replacing NIFTy's current NumPy-based implementation.
+
+### Current Features
+
+* MAP
+* MGVI
+* geoVI
+* Non-parametric correlated field
+
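+### Example
+
+The following is a minimal sketch of the purely functional interface (all values are illustrative only).
+The model, once finalized, is a plain function of a parameter dictionary and can be freely composed and `jax.jit`-compiled.
+
+```python
+import nifty8.re as jft
+
+cfm = jft.CorrelatedFieldMaker(prefix="cf")
+cfm.set_amplitude_total_offset(offset_mean=0., offset_std=(1e-1, 3e-2))
+cfm.add_fluctuations(
+    shape=(128, 128),
+    distances=1. / 128,
+    fluctuations=(1., 5e-1),
+    loglogavgslope=(-3., 1e-1),
+    flexibility=(1e0, 5e-1),
+    asperity=(5e-1, 5e-2),
+)
+correlated_field, parameter_tree = cfm.finalize()
+# `parameter_tree` specifies shape and dtype of every expected input of
+# `correlated_field`.
+```
+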
+### TODO
+
+The likelihood (or the Hamiltonian) is probably the object for which translating to a different interface makes the most sense.
+The minimization may differ depending on the API used, but the likelihood should be a common denominator.
+Inference schemes like MGVI, geoVI or MAP need not be similar to one another, nor should they be.
+For all of these methods a more functional approach is desired instead.
+
+Overall, it would make sense to re-implement `optimize_kl` from NIFTy because it abstracts away many details of how MGVI, geoVI or MAP are implemented.
+Furthermore, this would ease the transition from the NumPy-based NIFTy to a JAX-based NIFTy while at the same time allowing for many changes to the interfaces of MGVI, geoVI and MAP.
diff --git a/src/re/__init__.py b/src/re/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c170487ff1b48622b6a6b26fa0bc3f7e447975e6
--- /dev/null
+++ b/src/re/__init__.py
@@ -0,0 +1,68 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from . import refine
+from . import refine_util
+from . import refine_chart
+from . import lanczos
+from . import structured_kernel_interpolation
+from .conjugate_gradient import cg, static_cg
+from .correlated_field import CorrelatedFieldMaker, non_parametric_amplitude
+from .energy_operators import (
+    Categorical,
+    Gaussian,
+    Poissonian,
+    StudentT,
+    VariableCovarianceGaussian,
+    VariableCovarianceStudentT,
+)
+from .field import Field
+from .forest_util import (
+    ShapeWithDtype,
+    assert_arithmetics,
+    dot,
+    has_arithmetics,
+    map_forest,
+    map_forest_mean,
+    norm,
+    shape,
+    size,
+    stack,
+    unite,
+    unstack,
+    vdot,
+    zeros_like,
+)
+from .hmc import generate_hmc_acc_rej, generate_nuts_tree
+from .hmc_oo import HMCChain, NUTSChain
+from .kl import (
+    GeoMetricKL,
+    MetricKL,
+    geometrically_sample_standard_hamiltonian,
+    mean_hessp,
+    mean_metric,
+    mean_value_and_grad,
+    sample_standard_hamiltonian,
+)
+from .lanczos import stochastic_lq_logdet
+from .likelihood import Likelihood, StandardHamiltonian
+from .optimize import minimize, newton_cg, trust_ncg
+from .refine_chart import CoordinateChart, RefinementField
+from .stats_distributions import (
+    invgamma_invprior,
+    invgamma_prior,
+    laplace_prior,
+    lognormal_invprior,
+    lognormal_prior,
+    normal_prior,
+    uniform_prior,
+)
+from .sugar import (
+    ducktape,
+    ducktape_left,
+    interpolate,
+    mean,
+    mean_and_std,
+    random_like,
+    sum_of_squares,
+)
diff --git a/src/re/conjugate_gradient.py b/src/re/conjugate_gradient.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3415a06e360761531bbf49c205f52b94c330f0
--- /dev/null
+++ b/src/re/conjugate_gradient.py
@@ -0,0 +1,650 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import sys
+from datetime import datetime
+from functools import partial
+from jax import numpy as jnp
+from jax import lax
+
+from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+
+from .forest_util import assert_arithmetics, common_type, size, where, zeros_like
+from .forest_util import norm as jft_norm
+from .sugar import doc_from, sum_of_squares
+
+HessVP = Callable[[jnp.ndarray], jnp.ndarray]
+
+N_RESET = 20
+
+
+class CGResults(NamedTuple):
+    x: jnp.ndarray
+    nit: Union[int, jnp.ndarray]
+    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
+    info: Union[int, jnp.ndarray]
+    success: Union[bool, jnp.ndarray]
+
+
+def cg(mat, j, x0=None, *args, **kwargs) -> Tuple[Any, Union[int, jnp.ndarray]]:
+    """Solve `mat(x) = j` using Conjugate Gradient. `mat` must be callable and
+    represent a Hermitian, positive-definite matrix.
+
+    Notes
+    -----
+    If set, the parameters `absdelta` and `resnorm` always take precedence over
+    `tol` and `atol`.
+    """
+    assert_arithmetics(j)
+    if x0 is not None:
+        assert_arithmetics(x0)
+    cg_res = _cg(mat, j, x0, *args, **kwargs)
+    return cg_res.x, cg_res.info
+
+
+@doc_from(cg)
+def static_cg(mat, j, x0=None, *args, **kwargs):
+    assert_arithmetics(j)
+    if x0 is not None:
+        assert_arithmetics(x0)
+    cg_res = _static_cg(mat, j, x0, *args, **kwargs)
+    return cg_res.x, cg_res.info
+
+
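+# Minimal usage sketch (toy system; values are illustrative only):
+#
+#   import jax.numpy as jnp
+#   A = jnp.array([[4., 1.], [1., 3.]])  # Hermitian, positive definite
+#   x, info = cg(lambda v: A @ v, jnp.array([1., 2.]))  # solves A @ x = j
+#
+# `cg` uses Python control flow and is convenient for debugging while
+# `static_cg` is built on `lax.while_loop` and hence `jit`-compilable.
+
+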
+# Taken from nifty
+def _cg(
+    mat,
+    j,
+    x0=None,
+    *,
+    absdelta=None,
+    resnorm=None,
+    norm_ord=None,
+    tol=1e-5,  # taken from SciPy's linalg.cg
+    atol=0.,
+    miniter=None,
+    maxiter=None,
+    name=None,
+    time_threshold=None,
+    _within_newton=False
+) -> CGResults:
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimizer
+    miniter = min(
+        (6, maxiter if maxiter is not None else maxiter_fallback)
+    ) if miniter is None else miniter
+    maxiter = max(
+        (min((200, maxiter_fallback)), miniter)
+    ) if maxiter is None else maxiter
+
+    if absdelta is None and resnorm is None:  # fallback convergence criterion
+        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord, ravel=True), atol)
+
+    common_dtp = common_type(j)
+    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimizer
+    tiny = 6. * jnp.finfo(common_dtp).tiny
+
+    if x0 is None:
+        pos = zeros_like(j)
+        r = -j
+        d = r
+        # energy = .5xT M x - xT j
+        energy = 0.
+        nfev = 0
+    else:
+        pos = x0
+        r = mat(pos) - j
+        d = r
+        energy = float(((r - j) / 2).dot(pos))
+        nfev = 1
+    previous_gamma = float(sum_of_squares(r))
+    if previous_gamma == 0:
+        info = 0
+        return CGResults(x=pos, info=info, nit=0, nfev=nfev, success=True)
+
+    info = -1
+    i = 0
+    for i in range(1, maxiter + 1):
+        q = mat(d)
+        nfev += 1
+
+        curv = float(d.dot(q))
+        if curv == 0.:
+            if _within_newton:
+                info = 0
+                break
+            nm = "CG" if name is None else name
+            raise ValueError(f"{nm}: zero curvature")
+        elif curv < 0.:
+            if _within_newton and i > 1:
+                info = 0
+                break
+            elif _within_newton:
+                pos = previous_gamma / (-curv) * j
+                info = 0
+                break
+            nm = "CG" if name is None else name
+            raise ValueError(f"{nm}: negative curvature")
+        alpha = previous_gamma / curv
+        pos = pos - alpha * d
+        if i % N_RESET == 0:
+            r = mat(pos) - j
+            nfev += 1
+        else:
+            r = r - q * alpha
+        gamma = float(sum_of_squares(r))
+        if time_threshold is not None and datetime.now() > time_threshold:
+            info = i
+            break
+        if gamma >= 0. and gamma <= tiny:
+            nm = "CG" if name is None else name
+            print(f"{nm}: gamma=0, converged!", file=sys.stderr)
+            info = 0
+            break
+        if resnorm is not None:
+            norm = float(jft_norm(r, ord=norm_ord, ravel=True))
+            if name is not None:
+                msg = f"{name}: |∇|:{norm:.6e} 🞋:{resnorm:.6e}"
+                print(msg, file=sys.stderr)
+            if norm < resnorm and i >= miniter:
+                info = 0
+                break
+        if absdelta is not None or name is not None:
+            new_energy = float(((r - j) / 2).dot(pos))
+            energy_diff = energy - new_energy
+            if name is not None:
+                msg = (
+                    f"{name}: Iteration {i} ⛰:{new_energy:+.6e} Δ⛰:{energy_diff:.6e}"
+                    + (f" 🞋:{absdelta:.6e}" if absdelta is not None else "")
+                )
+                print(msg, file=sys.stderr)
+        else:
+            new_energy = energy
+        if absdelta is not None:
+            neg_energy_eps = -eps * jnp.abs(new_energy)
+            if energy_diff < neg_energy_eps:
+                nm = "CG" if name is None else name
+                raise ValueError(f"{nm}: WARNING: energy increased")
+            if neg_energy_eps <= energy_diff < absdelta and i >= miniter:
+                info = 0
+                break
+        energy = new_energy
+        d = d * max(0, gamma / previous_gamma) + r
+        previous_gamma = gamma
+    else:
+        nm = "CG" if name is None else name
+        print(f"{nm}: Iteration Limit Reached", file=sys.stderr)
+        info = i
+    return CGResults(x=pos, info=info, nit=i, nfev=nfev, success=info == 0)
+
+
+def _static_cg(
+    mat,
+    j,
+    x0=None,
+    *,
+    absdelta=None,
+    resnorm=None,
+    norm_ord=None,
+    tol=1e-5,  # taken from SciPy's linalg.cg
+    atol=0.,
+    miniter=None,
+    maxiter=None,
+    name=None,
+    _within_newton=False,  # TODO
+    **kwargs
+) -> CGResults:
+    from jax.lax import cond, while_loop
+
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimizer
+    miniter = jnp.minimum(
+        6, maxiter if maxiter is not None else maxiter_fallback
+    ) if miniter is None else miniter
+    maxiter = jnp.maximum(
+        jnp.minimum(200, maxiter_fallback), miniter
+    ) if maxiter is None else maxiter
+
+    if absdelta is None and resnorm is None:  # fallback convergence criterion
+        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord, ravel=True), atol)
+
+    common_dtp = common_type(j)
+    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimizer
+    tiny = 6. * jnp.finfo(common_dtp).tiny
+
+    def continue_condition(v):
+        return v["info"] < -1
+
+    def cg_single_step(v):
+        info = v["info"]
+        pos, r, d, i = v["pos"], v["r"], v["d"], v["iteration"]
+        previous_gamma, previous_energy = v["gamma"], v["energy"]
+
+        i += 1
+
+        q = mat(d)
+        curv = d.dot(q)
+        # ValueError("zero curvature in conjugate gradient")
+        info = jnp.where(curv == 0., -1, info)
+        alpha = previous_gamma / curv
+        # ValueError("implausible gradient scaling `alpha < 0`")
+        info = jnp.where(alpha < 0., -1, info)
+        pos = pos - alpha * d
+        r = cond(
+            i % N_RESET == 0, lambda x: mat(x["pos"]) - x["j"],
+            lambda x: x["r"] - x["q"] * x["alpha"], {
+                "pos": pos,
+                "j": j,
+                "r": r,
+                "q": q,
+                "alpha": alpha
+            }
+        )
+        gamma = sum_of_squares(r)
+
+        info = jnp.where(
+            (gamma >= 0.) & (gamma <= tiny) & (info != -1), 0, info
+        )
+        if resnorm is not None:
+            norm = jft_norm(r, ord=norm_ord, ravel=True)
+            info = jnp.where(
+                (norm < resnorm) & (i >= miniter) & (info != -1), 0, info
+            )
+        else:
+            norm = None
+        # Do not compute the energy if we do not check `absdelta`
+        if absdelta is not None or name is not None:
+            energy = ((r - j) / 2).dot(pos)
+            energy_diff = previous_energy - energy
+        else:
+            energy = previous_energy
+            energy_diff = None
+        if absdelta is not None:
+            neg_energy_eps = -eps * jnp.abs(energy)
+            # print(f"energy increased", file=sys.stderr)
+            info = jnp.where(energy_diff < neg_energy_eps, -1, info)
+            info = jnp.where(
+                (energy_diff >= neg_energy_eps) & (energy_diff < absdelta) &
+                (i >= miniter) & (info != -1), 0, info
+            )
+        info = jnp.where((i >= maxiter) & (info != -1), i, info)
+
+        d = d * jnp.maximum(0, gamma / previous_gamma) + r
+
+        if name is not None:
+            from jax.experimental.host_callback import call
+
+            def pp(arg):
+                msg = (
+                    (
+                        "{name}: |∇|:{norm:.6e} 🞋:{resnorm:.6e}\n"
+                        if arg["resnorm"] is not None else ""
+                    ) + "{name}: Iteration {i} ⛰:{energy:+.6e}" +
+                    " Δ⛰:{energy_diff:.6e}" + (
+                        " 🞋:{absdelta:.6e}"
+                        if arg["absdelta"] is not None else ""
+                    ) + (
+                        "\n{name}: Iteration Limit Reached"
+                        if arg["i"] == arg["maxiter"] else ""
+                    )
+                )
+                print(msg.format(name=name, **arg), file=sys.stderr)
+
+            printable_state = {
+                "i": i,
+                "energy": energy,
+                "energy_diff": energy_diff,
+                "absdelta": absdelta,
+                "norm": norm,
+                "resnorm": resnorm,
+                "maxiter": maxiter
+            }
+            call(pp, printable_state, result_shape=None)
+
+        ret = {
+            "info": info,
+            "pos": pos,
+            "r": r,
+            "d": d,
+            "iteration": i,
+            "gamma": gamma,
+            "energy": energy
+        }
+        return ret
+
+    if x0 is None:
+        pos = zeros_like(j)
+        r = -j
+        d = r
+        nfev = 0
+    else:
+        pos = x0
+        r = mat(pos) - j
+        d = r
+        nfev = 1
+    energy = None
+    if absdelta is not None or name is not None:
+        if x0 is None:
+            # energy = .5xT M x - xT j
+            energy = jnp.array(0.)
+        else:
+            energy = ((r - j) / 2).dot(pos)
+
+    gamma = sum_of_squares(r)
+    val = {
+        "info": jnp.array(-2, dtype=int),
+        "pos": pos,
+        "r": r,
+        "d": d,
+        "iteration": jnp.array(0),
+        "gamma": gamma,
+        "energy": energy
+    }
+    # Finish early if already converged in the initial iteration
+    val["info"] = jnp.where(gamma == 0., 0, val["info"])
+
+    val = while_loop(continue_condition, cg_single_step, val)
+
+    i = val["iteration"]
+    info = val["info"]
+    nfev += i + i // N_RESET
+    return CGResults(
+        x=val["pos"], info=info, nit=i, nfev=nfev, success=info == 0
+    )
+
+
+# The following is code adapted from Nicholas Mancuso to work with pytrees
+class _QuadSubproblemResult(NamedTuple):
+    step: jnp.ndarray
+    hits_boundary: Union[bool, jnp.ndarray]
+    pred_f: Union[float, jnp.ndarray]
+    nit: Union[int, jnp.ndarray]
+    nfev: Union[int, jnp.ndarray]
+    njev: Union[int, jnp.ndarray]
+    nhev: Union[int, jnp.ndarray]
+    success: Union[bool, jnp.ndarray]
+
+
+class _CGSteihaugState(NamedTuple):
+    z: jnp.ndarray
+    r: jnp.ndarray
+    d: jnp.ndarray
+    step: jnp.ndarray
+    energy: Union[None, float, jnp.ndarray]
+    hits_boundary: Union[bool, jnp.ndarray]
+    done: Union[bool, jnp.ndarray]
+    nit: Union[int, jnp.ndarray]
+    nhev: Union[int, jnp.ndarray]
+
+
+def second_order_approx(
+    p: jnp.ndarray,
+    cur_val: Union[float, jnp.ndarray],
+    g: jnp.ndarray,
+    hessp_at_xk: HessVP,
+) -> Union[float, jnp.ndarray]:
+    return cur_val + g.dot(p) + 0.5 * p.dot(hessp_at_xk(p))
+
+
+def get_boundaries_intersections(
+    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
+):  # Adapted from SciPy
+    """Solve the scalar quadratic equation ||z + t d|| == trust_radius.
+
+    This is like a line-sphere intersection.
+
+    Return the two values of t, sorted from low to high.
+    """
+    a = d.dot(d)
+    b = 2 * z.dot(d)
+    c = z.dot(z) - trust_radius**2
+    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
+
+    # The following calculation is mathematically
+    # equivalent to:
+    # ta = (-b - sqrt_discriminant) / (2*a)
+    # tb = (-b + sqrt_discriminant) / (2*a)
+    # but produces smaller round-off errors.
+    # See "Matrix Computations", p. 97,
+    # for a better justification.
+    aux = b + jnp.copysign(sqrt_discriminant, b)
+    ta = -aux / (2 * a)
+    tb = -2 * c / aux
+
+    ra, rb = where(ta < tb, (ta, tb), (tb, ta))
+    return (ra, rb)
+
+
+def _cg_steihaug_subproblem(
+    cur_val: Union[float, jnp.ndarray],
+    g: jnp.ndarray,
+    hessp_at_xk: HessVP,
+    *,
+    trust_radius: Union[float, jnp.ndarray],
+    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
+    resnorm: Optional[float],
+    absdelta: Optional[float] = None,
+    norm_ord: Union[None, int, float, jnp.ndarray] = None,
+    miniter: Union[None, int] = None,
+    maxiter: Union[None, int] = None,
+    name=None
+) -> _QuadSubproblemResult:
+    """
+    Solve the subproblem using a conjugate gradient method.
+
+    Parameters
+    ----------
+    cur_val : Union[float, jnp.ndarray]
+      Objective value evaluated at the current state.
+    g : jnp.ndarray
+      Gradient value evaluated at the current state.
+    hessp_at_xk: Callable
+      Function that accepts a proposal vector and computes the result of a
+      Hessian-vector product.
+    trust_radius : float
+      Upper bound on how large a step proposal can be.
+    tr_norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
+      Order of the norm for computing the length of the next step.
+    norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
+      Order of the norm for testing convergence.
+
+    Returns
+    -------
+    result : _QuadSubproblemResult
+      Contains the step proposal, whether it is at radius boundary, and
+      meta-data regarding function calls and successful convergence.
+
+    Notes
+    -----
+    This is algorithm (7.2) of Nocedal and Wright 2nd edition.
+    Only the function that computes the Hessian-vector product is required.
+    The Hessian itself is not required, and the Hessian does
+    not need to be positive semidefinite.
+    """
+    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(g)  # taken from SciPy's NewtonCG minimizer
+    miniter = jnp.minimum(
+        6, maxiter if maxiter is not None else maxiter_fallback
+    ) if miniter is None else miniter
+    maxiter = jnp.maximum(
+        jnp.minimum(200, maxiter_fallback), miniter
+    ) if maxiter is None else maxiter
+
+    common_dtp = common_type(g)
+    # Inspired by SciPy's NewtonCG minimizer
+    eps = 6. * jnp.finfo(common_dtp).eps
+
+    # second-order Taylor series approximation at the current values, gradient,
+    # and hessian
+    soa = partial(
+        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
+    )
+
+    # helpers for internal switches in the main CGSteihaug logic
+    def noop(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        return iterp
+
+    def step1(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        z, d, nhev = iterp.z, iterp.d, iterp.nhev
+
+        ta, tb = get_boundaries_intersections(z, d, trust_radius)
+        pa = z + ta * d
+        pb = z + tb * d
+        p_boundary = where(soa(pa) < soa(pb), pa, pb)
+        return iterp._replace(
+            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
+        )
+
+    def step2(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        z, d = iterp.z, iterp.d
+
+        ta, tb = get_boundaries_intersections(z, d, trust_radius)
+        p_boundary = z + tb * d
+        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)
+
+    def step3(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        return iterp._replace(step=z_next, hits_boundary=False, done=True)
+
+    # initialize the step
+    p_origin = zeros_like(g)
+
+    # init the state for the first iteration
+    z = p_origin
+    r = g
+    d = -r
+    energy = 0. if absdelta is not None or name is not None else None
+    init_param = _CGSteihaugState(
+        z=z,
+        r=r,
+        d=d,
+        step=p_origin,
+        energy=energy,
+        hits_boundary=False,
+        done=maxiter == 0,
+        nit=0,
+        nhev=0
+    )
+
+    # Search for the min of the approximation of the objective function.
+    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
+        z, r, d = iterp.z, iterp.r, iterp.d
+        energy, nit = iterp.energy, iterp.nit
+
+        nit += 1
+
+        Bd = hessp_at_xk(d)
+        dBd = d.dot(Bd)
+
+        r_squared = r.dot(r)
+        alpha = r_squared / dBd
+        z_next = z + alpha * d
+
+        r_next = r + alpha * Bd
+        r_next_squared = r_next.dot(r_next)
+
+        beta_next = r_next_squared / r_squared
+        d_next = -r_next + beta_next * d
+
+        accept_z_next = nit >= maxiter
+        if norm_ord == 2:
+            r_next_norm = jnp.sqrt(r_next_squared)
+        else:
+            r_next_norm = jft_norm(r_next, ord=norm_ord, ravel=True)
+        accept_z_next |= r_next_norm < resnorm
+        if absdelta is not None or name is not None:
+            # Relative to a plain CG, `z_next` is negative
+            energy_next = ((r_next + g) / 2).dot(z_next)
+            energy_diff = energy - energy_next
+        else:
+            energy_next = energy
+            energy_diff = jnp.nan
+        if absdelta is not None:
+            neg_energy_eps = -eps * jnp.abs(energy)
+            accept_z_next |= (energy_diff >= neg_energy_eps
+                             ) & (energy_diff < absdelta) & (nit >= miniter)
+
+        # include a junk switch to catch the case where none should be executed
+        z_next_norm = jft_norm(z_next, ord=tr_norm_ord, ravel=True)
+        index = jnp.argmax(
+            jnp.array(
+                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
+            )
+        )
+        iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next))
+
+        iterp = iterp._replace(
+            z=z_next,
+            r=r_next,
+            d=d_next,
+            energy=energy_next,
+            nhev=iterp.nhev + 1,
+            nit=nit
+        )
+        if name is not None:
+            from jax.experimental.host_callback import call
+
+            def pp(arg):
+                msg = (
+                    "{name}: |∇|:{r_norm:.6e} 🞋:{resnorm:.6e} ↗:{tr:.6e}"
+                    " ☞:{case:1d} #∇²:{nhev:02d}"
+                    "\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}"
+                    + (
+                        " 🞋:{absdelta:.6e}"
+                        if arg["absdelta"] is not None else ""
+                    ) + (
+                        "\n{name}: Iteration Limit Reached"
+                        if arg["i"] == arg["maxiter"] else ""
+                    )
+                )
+                print(msg.format(name=name, **arg), file=sys.stderr)
+
+            printable_state = {
+                "i": nit,
+                "energy": iterp.energy,
+                "energy_diff": energy_diff,
+                "absdelta": absdelta,
+                "tr": trust_radius,
+                "r_norm": r_next_norm,
+                "resnorm": resnorm,
+                "nhev": iterp.nhev,
+                "case": index,
+                "maxiter": maxiter
+            }
+            call(pp, printable_state, result_shape=None)
+
+        return iterp
+
+    def cond_f(iterp: _CGSteihaugState) -> bool:
+        return jnp.logical_not(iterp.done)
+
+    # perform inner optimization to solve the constrained
+    # quadratic subproblem using cg
+    result = lax.while_loop(cond_f, body_f, init_param)
+
+    pred_f = soa(result.step)
+    result = _QuadSubproblemResult(
+        step=result.step,
+        hits_boundary=result.hits_boundary,
+        pred_f=pred_f,
+        nit=result.nit,
+        nfev=0,
+        njev=0,
+        nhev=result.nhev + 1,
+        success=True
+    )
+
+    return result
diff --git a/src/re/correlated_field.py b/src/re/correlated_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6da97379ffd3a20a2260a4aa57e4d25b3f59ce
--- /dev/null
+++ b/src/re/correlated_field.py
@@ -0,0 +1,511 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from collections.abc import Mapping
+from functools import partial
+import sys
+from typing import Callable, Dict, Optional, Tuple, Union
+
+from jax import numpy as jnp
+import numpy as np
+
+from .forest_util import ShapeWithDtype
+from .stats_distributions import lognormal_prior, normal_prior
+from .sugar import ducktape
+
+
+def _safe_assert(condition):
+    if not condition:
+        raise AssertionError()
+
+
+def hartley(p, axes=None):
+    from jax.numpy import fft
+
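+    # Adding real and imaginary part of the FFT yields the discrete Hartley
+    # transform; applying it twice scales by the number of pixels along
+    # `axes`, i.e. it is its own inverse up to this factor.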
+    tmp = fft.fftn(p, axes=axes)
+    return tmp.real + tmp.imag
+
+
+def get_fourier_mode_distributor(
+    shape: Union[tuple, int], distances: Union[tuple, float]
+):
+    """Get the unique lengths of the Fourier modes, a mapping from a mode to
+    its length index and the multiplicity of each unique Fourier mode length.
+
+    Parameters
+    ----------
+    shape : tuple of int or int
+        Position-space shape.
+    distances : tuple of float or float
+        Position-space distances.
+
+    Returns
+    -------
+    mode_length_idx : jnp.ndarray
+        Index in power-space for every mode in harmonic-space. Can be used to
+        distribute power from a power-space to the full harmonic domain.
+    unique_mode_length : jnp.ndarray
+        Unique length of Fourier modes.
+    mode_multiplicity : jnp.ndarray
+        Multiplicity for each unique Fourier mode length.
+    """
+    shape = (shape, ) if isinstance(shape, int) else tuple(shape)
+
+    # Compute length of modes
+    mspc_distances = 1. / (jnp.array(shape) * jnp.array(distances))
+    m_length = jnp.arange(shape[0], dtype=jnp.float64)
+    m_length = jnp.minimum(m_length, shape[0] - m_length) * mspc_distances[0]
+    if len(shape) != 1:
+        m_length *= m_length
+        for i in range(1, len(shape)):
+            tmp = jnp.arange(shape[i], dtype=jnp.float64)
+            tmp = jnp.minimum(tmp, shape[i] - tmp) * mspc_distances[i]
+            tmp *= tmp
+            m_length = jnp.expand_dims(m_length, axis=-1) + tmp
+        m_length = jnp.sqrt(m_length)
+
+    # Construct an array of unique mode lengths
+    uniqueness_rtol = 1e-12
+    um = jnp.unique(m_length)
+    tol = uniqueness_rtol * um[-1]
+    um = um[jnp.diff(jnp.append(um, 2 * um[-1])) > tol]
+    # Group modes based on their length and store the result as power
+    # distributor
+    binbounds = 0.5 * (um[:-1] + um[1:])
+    m_length_idx = jnp.searchsorted(binbounds, m_length)
+    m_count = jnp.bincount(m_length_idx.ravel(), minlength=um.size)
+    if jnp.any(m_count == 0) or um.shape != m_count.shape:
+        raise RuntimeError("invalid harmonic mode(s) encountered")
+
+    return m_length_idx, um, m_count
+
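+# For example, a one-dimensional space with four pixels and unit distance
+# yields
+#   get_fourier_mode_distributor(4, 1.)
+#   # -> mode_length_idx == [0, 1, 2, 1], unique_mode_length == [0., 0.25, 0.5]
+#   #    and mode_multiplicity == [1, 2, 1]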
+
+def _twolog_integrate(log_vol, x):
+    # Map the space to the one for the relative log-modes, i.e. pad the space
+    # of the log volume
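+    # The first cumulative sum turns the increments in `x[1]` into a Wiener
+    # process; the trapezoidal rule weighted by `log_vol` plus the additive
+    # term `x[0]` and a second cumulative sum then integrate that process.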
+    twolog = jnp.empty((2 + log_vol.shape[0], ))
+    twolog = twolog.at[0].set(0.)
+    twolog = twolog.at[1].set(0.)
+
+    twolog = twolog.at[2:].set(jnp.cumsum(x[1], axis=0))
+    twolog = twolog.at[2:].set(
+        (twolog[2:] + twolog[1:-1]) / 2. * log_vol + x[0]
+    )
+    twolog = twolog.at[2:].set(jnp.cumsum(twolog[2:], axis=0))
+    return twolog
+
+
+def _remove_slope(rel_log_mode_dist, x):
+    sc = rel_log_mode_dist / rel_log_mode_dist[-1]
+    return x - x[-1] * sc
+
+
+def non_parametric_amplitude(
+    domain: Mapping,
+    fluctuations: Callable,
+    loglogavgslope: Callable,
+    flexibility: Optional[Callable] = None,
+    asperity: Optional[Callable] = None,
+    prefix: str = "",
+    kind: str = "amplitude",
+) -> Tuple[Callable, Dict[str, ShapeWithDtype]]:
+    """Constructs an function computing the amplitude of a non-parametric power
+    spectrum
+
+    See
+    :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations`
+    for more details on the parameters.
+
+    See also
+    --------
+    `Variable structures in M87* from space, time and frequency resolved
+    interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
+    and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
+    Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
+    `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
+    """
+    totvol = domain.get("position_space_total_volume", 1.)
+    rel_log_mode_len = domain["relative_log_mode_lengths"]
+    mode_multiplicity = domain["mode_multiplicity"]
+    log_vol = domain.get("log_volume")
+
+    ptree = {}
+    fluctuations = ducktape(fluctuations, prefix + "fluctuations")
+    ptree[prefix + "fluctuations"] = ShapeWithDtype(())
+    loglogavgslope = ducktape(loglogavgslope, prefix + "loglogavgslope")
+    ptree[prefix + "loglogavgslope"] = ShapeWithDtype(())
+    if flexibility is not None:
+        flexibility = ducktape(flexibility, prefix + "flexibility")
+        ptree[prefix + "flexibility"] = ShapeWithDtype(())
+        # Register the parameters for the spectrum
+        _safe_assert(log_vol is not None)
+        _safe_assert(rel_log_mode_len.ndim == log_vol.ndim == 1)
+        ptree[prefix + "spectrum"] = ShapeWithDtype((2, ) + log_vol.shape)
+    if asperity is not None:
+        asperity = ducktape(asperity, prefix + "asperity")
+        ptree[prefix + "asperity"] = ShapeWithDtype(())
+
+    def correlate(primals: Mapping) -> jnp.ndarray:
+        flu = fluctuations(primals)
+        slope = loglogavgslope(primals)
+        slope *= rel_log_mode_len
+        ln_spectrum = slope
+
+        if flexibility is not None:
+            _safe_assert(log_vol is not None)
+            xi_spc = primals[prefix + "spectrum"]
+            flx = flexibility(primals)
+            sig_flx = flx * jnp.sqrt(log_vol)
+            sig_flx = jnp.broadcast_to(sig_flx, (2, ) + log_vol.shape)
+
+            if asperity is None:
+                shift = jnp.stack(
+                    (log_vol / jnp.sqrt(12.), jnp.ones_like(log_vol)), axis=0
+                )
+                asp = shift * sig_flx * xi_spc
+            else:
+                asp = asperity(primals)
+                shift = jnp.stack(
+                    (log_vol**2 / 12., jnp.ones_like(log_vol)), axis=0
+                )
+                sig_asp = jnp.broadcast_to(
+                    jnp.array([[asp], [0.]]), shift.shape
+                )
+                asp = jnp.sqrt(shift + sig_asp) * sig_flx * xi_spc
+
+            twolog = _twolog_integrate(log_vol, asp)
+            wo_slope = _remove_slope(rel_log_mode_len, twolog)
+            ln_spectrum += wo_slope
+
+        # Exponentiate and norm the power spectrum
+        spectrum = jnp.exp(ln_spectrum)
+        # Take the sqrt of the integral of the slope w/o fluctuations and
+        # zero-mode while taking into account the multiplicity of each mode
+        if kind.lower() == "amplitude":
+            norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]**2))
+            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
+            amplitude = flu * (jnp.sqrt(totvol) / norm) * spectrum
+        elif kind.lower() == "power":
+            norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]))
+            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
+            amplitude = flu * (jnp.sqrt(totvol) / norm) * jnp.sqrt(spectrum)
+        else:
+            raise ValueError(f"invalid kind specified {kind!r}")
+        amplitude = amplitude.at[0].set(totvol)
+        return amplitude
+
+    return correlate, ptree
+
+
+class CorrelatedFieldMaker():
+    """Construction helper for hierarchical correlated field models.
+
+    The correlated field models are parametrized by creating square roots of
+    power spectrum operators ("amplitudes") via calls to
+    :func:`add_fluctuations*` that act on the targeted field subdomains.
+    During creation of the :class:`CorrelatedFieldMaker`, a global offset from
+    zero of the field model can be defined and an operator applying
+    fluctuations around this offset is parametrized.
+
+    Creation of the model operator is completed by calling the method
+    :func:`finalize`, which returns the configured operator.
+
+    See the initializer, :func:`add_fluctuations` and
+    :func:`finalize` for further usage information."""
+    def __init__(self, prefix: str):
+        """Instantiate a CorrelatedFieldMaker object.
+
+        Parameters
+        ----------
+        prefix : string
+            Prefix to the names of the domains of the cf operator to be made.
+            This determines the names of the operator domain.
+        """
+        self._azm = None
+        self._offset_mean = None
+        self._fluctuations = []
+        self._target_subdomains = []
+        self._parameter_tree = {}
+
+        self._prefix = prefix
+
+    def add_fluctuations(
+        self,
+        shape: Union[tuple, int],
+        distances: Union[tuple, float],
+        fluctuations: Union[tuple, Callable],
+        loglogavgslope: Union[tuple, Callable],
+        flexibility: Union[tuple, Callable, None] = None,
+        asperity: Union[tuple, Callable, None] = None,
+        prefix: str = "",
+        harmonic_domain_type: str = "fourier",
+        non_parametric_kind: str = "amplitude",
+    ):
+        """Adds a correlation structure to the to-be-made field.
+
+        Correlations are described by their power spectrum and the subdomain on
+        which they apply.
+
+        Multiple calls to `add_fluctuations` are possible, in which case
+        the constructed field will have the outer product of the individual
+        power spectra as its global power spectrum.
+
+        The parameters `fluctuations`, `flexibility`, `asperity` and
+        `loglogavgslope` configure either the amplitude or the power
+        spectrum model used on the target field subdomain of type
+        `harmonic_domain_type`. It is assembled as the sum of a power
+        law component (linear slope in log-log
+        amplitude-frequency-space), a smooth varying component
+        (integrated Wiener process) and a ragged component
+        (un-integrated Wiener process).
+
+        Parameters
+        ----------
+        shape : tuple of int or int
+            Shape of the position space domain
+        distances : tuple of float or float
+            Distances in the position space domain
+        fluctuations : tuple of float (mean, std) or callable
+            Total spectral energy, i.e. amplitude of the fluctuations
+            (by default a priori log-normal distributed)
+        loglogavgslope : tuple of float (mean, std) or callable
+            Power law component exponent
+            (by default a priori normal distributed)
+        flexibility : tuple of float (mean, std) or callable or None
+            Amplitude of the non-power-law power spectrum component
+            (by default a priori log-normal distributed)
+        asperity : tuple of float (mean, std) or callable or None
+            Roughness of the non-power-law power spectrum component; use it to
+            accommodate a single frequency peak
+            (by default a priori log-normal distributed)
+        prefix : str
+            Prefix of the power spectrum parameter domain names
+        harmonic_domain_type : str
+            Description of the harmonic partner domain in which the amplitude
+            lives
+        non_parametric_kind : str
+            If `"amplitude"`, the non-parametric model describes the amplitude
+            spectrum; if `"power"`, it describes the power spectrum (see
+            `non_parametric_amplitude`)
+
+        See also
+        --------
+        `Variable structures in M87* from space, time and frequency resolved
+        interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
+        and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
+        Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
+        `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
+        """
+        shape = (shape, ) if isinstance(shape, int) else tuple(shape)
+        distances = tuple(np.broadcast_to(distances, jnp.shape(shape)))
+        totvol = jnp.prod(jnp.array(shape) * jnp.array(distances))
+
+        # Pre-compute lengths of modes and indices for distributing power
+        # TODO: cache results such that only references are used afterwards
+        domain = {
+            "position_space_shape": shape,
+            "position_space_total_volume": totvol,
+            "position_space_distances": distances,
+            "harmonic_domain_type": harmonic_domain_type.lower()
+        }
+        if harmonic_domain_type.lower() == "fourier":
+            domain["harmonic_space_shape"] = shape
+            m_length_idx, um, m_count = get_fourier_mode_distributor(
+                shape, distances
+            )
+            domain["power_distributor"] = m_length_idx
+            domain["mode_multiplicity"] = m_count
+
+            # Transform the unique modes to log-space for the amplitude model
+            um = um.at[1:].set(jnp.log(um[1:]))
+            um = um.at[1:].add(-um[1])
+            _safe_assert(um[0] == 0.)
+            domain["relative_log_mode_lengths"] = um
+            log_vol = um[2:] - um[1:-1]
+            _safe_assert(um.shape[0] - 2 == log_vol.shape[0])
+            domain["log_volume"] = log_vol
+        else:
+            ve = f"invalid `harmonic_domain_type` {harmonic_domain_type!r}"
+            raise ValueError(ve)
+
+        flu = fluctuations
+        if isinstance(flu, (tuple, list)):
+            flu = lognormal_prior(*flu)
+        elif not callable(flu):
+            te = f"invalid `fluctuations` specified; got '{type(fluctuations)}'"
+            raise TypeError(te)
+        slp = loglogavgslope
+        if isinstance(slp, (tuple, list)):
+            slp = normal_prior(*slp)
+        elif not callable(slp):
+            te = f"invalid `loglogavgslope` specified; got '{type(loglogavgslope)}'"
+            raise TypeError(te)
+
+        flx = flexibility
+        if isinstance(flx, (tuple, list)):
+            flx = lognormal_prior(*flx)
+        elif flx is not None and not callable(flx):
+            te = f"invalid `flexibility` specified; got '{type(flexibility)}'"
+            raise TypeError(te)
+        asp = asperity
+        if isinstance(asp, (tuple, list)):
+            asp = lognormal_prior(*asp)
+        elif asp is not None and not callable(asp):
+            te = f"invalid `asperity` specified; got '{type(asperity)}'"
+            raise TypeError(te)
+
+        npa, ptree = non_parametric_amplitude(
+            domain=domain,
+            fluctuations=flu,
+            loglogavgslope=slp,
+            flexibility=flx,
+            asperity=asp,
+            prefix=self._prefix + prefix,
+            kind=non_parametric_kind,
+        )
+        self._fluctuations.append(npa)
+        self._target_subdomains.append(domain)
+        self._parameter_tree.update(ptree)
+
+    def set_amplitude_total_offset(
+        self, offset_mean: float, offset_std: Union[tuple, Callable]
+    ):
+        """Sets the zero-mode for the combined amplitude operator
+
+        Parameters
+        ----------
+        offset_mean : float
+            Mean offset from zero of the correlated field to be made.
+        offset_std : tuple of float or callable
+            Mean standard deviation and standard deviation of the standard
+            deviation of the offset. No, this is not a word duplication.
+            (By default a priori log-normal distributed)
+        """
+        if self._offset_mean is not None and self._azm is not None:
+            msg = "Overwriting the previous mean offset and zero-mode"
+            print(msg, file=sys.stderr)
+
+        self._offset_mean = offset_mean
+        zm = offset_std
+        if not callable(zm):
+            if zm is None or len(zm) != 2:
+                raise TypeError(f"`offset_std` of invalid type {type(zm)!r}")
+            zm = lognormal_prior(*zm)
+
+        self._azm = ducktape(zm, self._prefix + "zeromode")
+        self._parameter_tree[self._prefix + "zeromode"] = ShapeWithDtype(())
+
+    @property
+    def amplitude_total_offset(self) -> Callable:
+        """Returns the total offset of the amplitudes"""
+        if self._azm is None:
+            nie = "You need to set the `amplitude_total_offset` first"
+            raise NotImplementedError(nie)
+        return self._azm
+
+    @property
+    def azm(self):
+        """Alias for `amplitude_total_offset`"""
+        return self.amplitude_total_offset
+
+    @property
+    def fluctuations(self) -> Tuple[Callable, ...]:
+        """Returns the added fluctuations, i.e. un-normalized amplitudes
+
+        Their scales are only meaningful relative to one another. Their
+        absolute scale bares no information.
+        """
+        return tuple(self._fluctuations)
+
+    def get_normalized_amplitudes(self) -> Tuple[Callable, ...]:
+        """Returns the normalized amplitude operators used in the final model
+
+        The amplitude operators are corrected for the otherwise degenerate
+        zero-mode. Their scales are only meaningful relative to one another.
+        Their absolute scale bares no information.
+        """
+        def _mk_normed_amp(amp):  # Avoid late binding
+            def normed_amplitude(p):
+                return amp(p).at[1:].mul(1. / self.azm(p))
+
+            return normed_amplitude
+
+        return tuple(_mk_normed_amp(amp) for amp in self._fluctuations)
+
+    @property
+    def amplitude(self) -> Callable:
+        """Returns the added fluctuation, i.e. un-normalized amplitude"""
+        if len(self._fluctuations) > 1:
+            s = (
+                'If more than one spectrum is present in the model,'
+                ' no unique set of amplitudes exist because only the'
+                ' relative scale is determined.'
+            )
+            raise NotImplementedError(s)
+        amp = self._fluctuations[0]
+
+        def amplitude_w_zm(p):
+            return amp(p).at[0].mul(self.azm(p))
+
+        return amplitude_w_zm
+
+    @property
+    def power_spectrum(self) -> Callable:
+        """Returns the power spectrum"""
+        amp = self.amplitude
+
+        def power(p):
+            return amp(p)**2
+
+        return power
+
+    def finalize(self) -> Tuple[Callable, Dict[str, ShapeWithDtype]]:
+        """Finishes off the model construction process and returns the
+        constructed operator.
+        """
+        harmonic_transforms = []
+        excitation_shape = ()
+        for sub_dom in self._target_subdomains:
+            sub_shp = sub_dom["harmonic_space_shape"]
+            excitation_shape += sub_shp
+            n = len(excitation_shape)
+            axes = tuple(range(n - len(sub_shp), n))
+
+            # TODO: Generalize to complex
+            harmonic_dvol = 1. / sub_dom["position_space_total_volume"]
+            harmonic_transforms.append((harmonic_dvol, partial(hartley, axes=axes)))
+        # Register the parameters for the excitations in harmonic space
+        # TODO: actually account for the dtype here
+        pfx = self._prefix + "xi"
+        self._parameter_tree[pfx] = ShapeWithDtype(excitation_shape)
+
+        def outer_harmonic_transform(p):
+            harmonic_dvol, ht = harmonic_transforms[0]
+            outer = harmonic_dvol * ht(p)
+            for harmonic_dvol, ht in harmonic_transforms[1:]:
+                outer = harmonic_dvol * ht(outer)
+            return outer
+
+        def _mk_expanded_amp(amp, sub_dom):  # Avoid late binding
+            def expanded_amp(p):
+                return amp(p)[sub_dom["power_distributor"]]
+
+            return expanded_amp
+
+        expanded_amplitudes = []
+        namps = self.get_normalized_amplitudes()
+        for amp, sub_dom in zip(namps, self._target_subdomains):
+            expanded_amplitudes.append(_mk_expanded_amp(amp, sub_dom))
+
+        def outer_amplitude(p):
+            outer = expanded_amplitudes[0](p)
+            for amp in expanded_amplitudes[1:]:
+                # NOTE, the order is important here and must match with the
+                # excitations
+                # TODO, use functions instead and utilize numpy's casting
+                outer = jnp.tensordot(outer, amp(p), axes=0)
+            return outer
+
+        def correlated_field(p):
+            ea = outer_amplitude(p)
+            cf_h = self.azm(p) * ea * p[self._prefix + "xi"]
+            return self._offset_mean + outer_harmonic_transform(cf_h)
+
+        return correlated_field, self._parameter_tree.copy()
diff --git a/src/re/disable_jax_control_flow.py b/src/re/disable_jax_control_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f88d92d8dd2bcc7f3134d8a71a8e65743e2e3b7
--- /dev/null
+++ b/src/re/disable_jax_control_flow.py
@@ -0,0 +1,36 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from jax import lax
+
+_DISABLE_CONTROL_FLOW_PRIM = False
+
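+# With `_DISABLE_CONTROL_FLOW_PRIM` set to `True`, the wrappers below fall
+# back to plain Python control flow. This allows stepping through loop and
+# branch bodies with a debugger at the cost of `jit`-compatibility.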
+
+def cond(pred, true_fun, false_fun, operand):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        if pred:
+            return true_fun(operand)
+        else:
+            return false_fun(operand)
+    else:
+        return lax.cond(pred, true_fun, false_fun, operand)
+
+
+def while_loop(cond_fun, body_fun, init_val):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        val = init_val
+        while cond_fun(val):
+            val = body_fun(val)
+        return val
+    else:
+        return lax.while_loop(cond_fun, body_fun, init_val)
+
+
+def fori_loop(lower, upper, body_fun, init_val):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        val = init_val
+        for i in range(int(lower), int(upper)):
+            val = body_fun(i, val)
+        return val
+    else:
+        return lax.fori_loop(lower, upper, body_fun, init_val)
diff --git a/src/re/energy_operators.py b/src/re/energy_operators.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d3eef606baf85aff624e234ef4f3a2c48e7ee2
--- /dev/null
+++ b/src/re/energy_operators.py
@@ -0,0 +1,385 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from typing import Callable, Optional, Tuple
+
+import sys
+from jax import numpy as jnp
+from jax.tree_util import tree_map
+
+from .forest_util import ShapeWithDtype
+from .likelihood import Likelihood
+
+
+def standard_t(nwr, dof):
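+    """Negative log-probability, up to an additive constant, of the
+    noise-weighted residuals `nwr` under a Student's t distribution with
+    `dof` degrees of freedom."""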
+    return jnp.sum(jnp.log1p(nwr**2 / dof) * (dof + 1)) / 2
+
+
+def _shape_w_fixed_dtype(dtype):
+    def shp_w_dtp(e):
+        return ShapeWithDtype(jnp.shape(e), dtype)
+
+    return shp_w_dtp
+
+
+def _get_cov_inv_and_std_inv(
+    cov_inv: Optional[Callable],
+    std_inv: Optional[Callable],
+    primals=None
+) -> Tuple[Callable, Callable]:
+    if not cov_inv and not std_inv:
+
+        def cov_inv(tangents):
+            return tangents
+
+        def std_inv(tangents):
+            return tangents
+
+    elif not cov_inv:
+        wm = (
+            "assuming a diagonal covariance matrix"
+            ";\nsetting `cov_inv` to `std_inv(jnp.ones_like(data))**2`"
+        )
+        print(wm, file=sys.stderr)
+        noise_std_inv_sq = std_inv(tree_map(jnp.ones_like, primals))**2
+
+        def cov_inv(tangents):
+            return noise_std_inv_sq * tangents
+
+    elif not std_inv:
+        wm = (
+            "assuming a diagonal covariance matrix"
+            ";\nsetting `std_inv` to `cov_inv(jnp.ones_like(data))**0.5`"
+        )
+        print(wm, file=sys.stderr)
+        noise_cov_inv_sqrt = tree_map(
+            jnp.sqrt, cov_inv(tree_map(jnp.ones_like, primals))
+        )
+
+        def std_inv(tangents):
+            return noise_cov_inv_sqrt * tangents
+
+    if not (callable(cov_inv) and callable(std_inv)):
+        raise ValueError("received un-callable input")
+    return cov_inv, std_inv
+
+
+def Gaussian(
+    data,
+    noise_cov_inv: Optional[Callable] = None,
+    noise_std_inv: Optional[Callable] = None
+):
+    """Gaussian likelihood of the data
+
+    Parameters
+    ----------
+    data : tree-like structure of jnp.ndarray and float
+        Data with additive noise following a Gaussian distribution.
+    noise_cov_inv : callable acting on type of data
+        Function applying the inverse noise covariance of the Gaussian.
+    noise_std_inv : callable acting on type of data
+        Function applying the square root of the inverse noise covariance.
+
+    Notes
+    -----
+    If `noise_std_inv` is `None`, it is inferred from `noise_cov_inv` by
+    assuming a diagonal noise covariance, i.e. by applying `noise_cov_inv` to
+    a vector of ones and taking the square root of the result. If both
+    `noise_cov_inv` and `noise_std_inv` are `None`, a unit covariance is
+    assumed.
+    """
+    noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv(
+        noise_cov_inv, noise_std_inv, data
+    )
+
+    def hamiltonian(primals):
+        p_res = primals - data
+        return 0.5 * p_res.ravel().dot(noise_cov_inv(p_res).ravel())
+
+    def metric(primals, tangents):
+        return noise_cov_inv(tangents)
+
+    def left_sqrt_metric(primals, tangents):
+        return noise_std_inv(tangents)
+
+    def transformation(primals):
+        return noise_std_inv(primals)
+
+    lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, data)
+
+    return Likelihood(
+        hamiltonian,
+        transformation=transformation,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
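+# Usage sketch for i.i.d. noise with a (purely illustrative) standard
+# deviation of 0.3: `Gaussian(data, noise_std_inv=lambda x: x / 0.3)`; the
+# inverse covariance is then inferred as described in the Notes above.
+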
+
+def StudentT(
+    data,
+    dof,
+    noise_cov_inv: Optional[Callable] = None,
+    noise_std_inv: Optional[Callable] = None
+):
+    """Student's t likelihood of the data
+
+    Parameters
+    ----------
+    data : tree-like structure of jnp.ndarray and float
+        Data with additive noise following a Student's t distribution.
+    dof : tree-like structure of jnp.ndarray and float
+        Degree-of-freedom parameter of Student's t distribution.
+    noise_cov_inv : callable acting on type of data
+        Function applying the inverse noise covariance of the Gaussian.
+    noise_std_inv : callable acting on type of data
+        Function applying the square root of the inverse noise covariance.
+
+    Notes
+    -----
+    If `noise_std_inv` is `None`, it is inferred from `noise_cov_inv` by
+    assuming a diagonal noise covariance, i.e. by applying `noise_cov_inv` to
+    a vector of ones and taking the square root of the result. If both
+    `noise_cov_inv` and `noise_std_inv` are `None`, a unit covariance is
+    assumed.
+    """
+    noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv(
+        noise_cov_inv, noise_std_inv, data
+    )
+
+    def hamiltonian(primals):
+        """
+        primals : mean
+        """
+        return standard_t(noise_std_inv(data - primals), dof)
+
+    def metric(primals, tangents):
+        """
+        primals, tangent : mean
+        """
+        return noise_cov_inv((dof + 1) / (dof + 3) * tangents)
+
+    def left_sqrt_metric(primals, tangents):
+        """
+        primals, tangents : mean
+        """
+        return noise_std_inv(jnp.sqrt((dof + 1) / (dof + 3)) * tangents)
+
+    def transformation(primals):
+        """
+        primals : mean
+        """
+        return noise_std_inv(jnp.sqrt((dof + 1) / (dof + 3)) * primals)
+
+    lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, data)
+
+    return Likelihood(
+        hamiltonian,
+        transformation=transformation,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
+
+def Poissonian(data, sampling_dtype=float):
+    """Computes the negative log-likelihood, i.e. the Hamiltonians of an
+    expected count field constrained by Poissonian count data.
+
+    Represents up to an f-independent term :math:`log(d!)`:
+
+    .. math ::
+        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),
+
+    where f is a field in data space of the expectation values for the counts.
+
+    Parameters
+    ----------
+    data : ndarray of uint
+        Data field with counts. Needs to have integer dtype and all values need
+        to be non-negative.
+    sampling_dtype : dtype, optional
+        Data-type for sampling.
+    """
+    from .forest_util import common_type
+
+    dtp = common_type(data)
+    if not jnp.issubdtype(dtp, jnp.integer):
+        raise TypeError("`data` of invalid type")
+    if jnp.any(data < 0):
+        raise ValueError("`data` may not be negative")
+
+    def hamiltonian(primals):
+        return jnp.sum(primals) - jnp.vdot(jnp.log(primals), data)
+
+    def metric(primals, tangents):
+        return tangents / primals
+
+    def left_sqrt_metric(primals, tangents):
+        return tangents / jnp.sqrt(primals)
+
+    def transformation(primals):
+        return jnp.sqrt(primals) * 2.
+
+    lsm_tangents_shape = tree_map(_shape_w_fixed_dtype(sampling_dtype), data)
+
+    return Likelihood(
+        hamiltonian,
+        transformation=transformation,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
+
+def VariableCovarianceGaussian(data):
+    """Gaussian likelihood of the data with a variable covariance
+
+    Parameters
+    ----------
+    data : tree-like structure of jnp.ndarray and float
+        Data with additive noise following a Gaussian distribution.
+
+    Notes
+    -----
+    The likelihood acts on a tuple of (mean, std_inv).
+    """
+    from .sugar import sum_of_squares
+
+    def hamiltonian(primals):
+        """
+        primals : pair of (mean, std_inv)
+        """
+        res = (primals[0] - data) * primals[1]
+        return 0.5 * sum_of_squares(res) - jnp.sum(jnp.log(primals[1]))
+
+    def metric(primals, tangents):
+        """
+        primals, tangent : pair of (mean, std_inv)
+        """
+        prim_std_inv_sq = primals[1]**2
+        res = (prim_std_inv_sq * tangents[0], 2 * tangents[1] / prim_std_inv_sq)
+        return type(primals)(res)
+
+    def left_sqrt_metric(primals, tangents):
+        """
+        primals, tangent : pair of (mean, std_inv)
+        """
+        res = (primals[1] * tangents[0], jnp.sqrt(2) * tangents[1] / primals[1])
+        return type(primals)(res)
+
+    def transformation(primals):
+        """
+        primals : pair of (mean, std_inv)
+
+        Notes
+        -----
+        A global transformation to Euclidean space does not exist. A local
+        approximation invoking the residual is used instead.
+        """
+        # TODO: test by drawing synthetic data that actually follows the
+        # noise-cov and then average over it
+        res = (primals[1] * (primals[0] - data), tree_map(jnp.log, primals[1]))
+        return type(primals)(res)
+
+    lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, (data, data))
+
+    return Likelihood(
+        hamiltonian,
+        transformation=transformation,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
+
+def VariableCovarianceStudentT(data, dof):
+    """Student's t likelihood of the data with a variable covariance
+
+    Parameters
+    ----------
+    data : tree-like structure of jnp.ndarray and float
+        Data with additive noise following a Student's t distribution.
+    dof : tree-like structure of jnp.ndarray and float
+        Degrees-of-freedom parameter of Student's t distribution.
+
+    Notes
+    -----
+    The likelihood acts on a tuple of (mean, std).
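+
+    Examples
+    --------
+    A minimal sketch with made-up values; calling the returned
+    :class:`Likelihood` is assumed here to evaluate the Hamiltonian:
+
+    >>> from jax import numpy as jnp
+    >>> lh = VariableCovarianceStudentT(jnp.zeros(3), dof=2.)
+    >>> mean, std = jnp.zeros(3), jnp.ones(3)
+    >>> energy = lh((mean, std))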
+    """
+    def hamiltonian(primals):
+        """
+        primals : pair of (mean, std)
+        """
+        t = standard_t((data - primals[0]) / primals[1], dof)
+        t += jnp.sum(jnp.log(primals[1]))
+        return t
+
+    def metric(primals, tangents):
+        """
+        primals, tangents : pair of (mean, std)
+        """
+        return (
+            tangents[0] * (dof + 1) / (dof + 3) / primals[1]**2,
+            tangents[1] * 2 * dof / (dof + 3) / primals[1]**2
+        )
+
+    def left_sqrt_metric(primals, tangents):
+        """
+        primals, tangents : pair of (mean, std)
+        """
+        cov = (
+            (dof + 1) / (dof + 3) / primals[1]**2,
+            2 * dof / (dof + 3) / primals[1]**2
+        )
+        res = (jnp.sqrt(cov[0]) * tangents[0], jnp.sqrt(cov[1]) * tangents[1])
+        return res
+
+    lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, (data, data))
+
+    return Likelihood(
+        hamiltonian,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
+
+def Categorical(data, axis=-1, sampling_dtype=float):
+    """Categorical likelihood of the data, equivalent to cross entropy
+
+    Parameters
+    ----------
+    data : sequence of int
+        An array stating which of the categories is realized in the data.
+        Must agree with the input shape except along `axis`.
+    axis : int
+        Axis along which the categories are formed.
+    sampling_dtype : dtype, optional
+        Data-type for sampling.
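+
+    Examples
+    --------
+    A minimal sketch with two data points and three categories each; calling
+    the returned :class:`Likelihood` is assumed here to evaluate the
+    Hamiltonian:
+
+    >>> from jax import numpy as jnp
+    >>> data = jnp.array([[2], [0]])  # indices of the realized categories
+    >>> lh = Categorical(data)
+    >>> logits = jnp.zeros((2, 3))  # unnormalized log-probabilities
+    >>> energy = lh(logits)  # == 2 * jnp.log(3.) for uniform logits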
+    """
+    def hamiltonian(primals):
+        from jax.nn import log_softmax
+        logits = log_softmax(primals, axis=axis)
+        return -jnp.sum(jnp.take_along_axis(logits, data, axis))
+
+    def metric(primals, tangents):
+        from jax.nn import softmax
+
+        preds = softmax(primals, axis=axis)
+        norm_term = jnp.sum(preds * tangents, axis=axis, keepdims=True)
+        return preds * tangents - preds * norm_term
+
+    def left_sqrt_metric(primals, tangents):
+        from jax.nn import softmax
+
+        sqrtp = jnp.sqrt(softmax(primals, axis=axis))
+        norm_term = jnp.sum(sqrtp * tangents, axis=axis, keepdims=True)
+        return sqrtp * (tangents - sqrtp * norm_term)
+
+    lsm_tangents_shape = tree_map(_shape_w_fixed_dtype(sampling_dtype), data)
+
+    return Likelihood(
+        hamiltonian,
+        left_sqrt_metric=left_sqrt_metric,
+        metric=metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
diff --git a/src/re/field.py b/src/re/field.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae10b7d08dcf6a9e5848186cb2ebdf2fedf9a3a
--- /dev/null
+++ b/src/re/field.py
@@ -0,0 +1,274 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import operator
+from jax import numpy as jnp
+from jax.tree_util import (
+    register_pytree_node_class, tree_leaves, tree_map, tree_structure
+)
+
+
+def _value_op(op, name=None):
+    def value_call(lhs, *args, **kwargs):
+        return op(lhs.val, *args, **kwargs)
+
+    name = op.__name__ if name is None else name
+    value_call.__name__ = f"__{name}__"
+    return value_call
+
+
+def _unary_op(op, name=None):
+    def unary_call(lhs):
+        return tree_map(op, lhs)
+
+    name = op.__name__ if name is None else name
+    unary_call.__name__ = f"__{name}__"
+    return unary_call
+
+
+def _enforce_flags(lhs, rhs):
+    flags = lhs.flags if isinstance(lhs, Field) else set()
+    flags |= rhs.flags if isinstance(rhs, Field) else set()
+    if "strict_domain_checking" in flags:
+        ts_lhs = tree_structure(lhs)
+        ts_rhs = tree_structure(rhs)
+
+        if not hasattr(rhs, "domain"):
+            te = f"RHS of type {type(rhs)} does not have a `domain` property"
+            raise TypeError(te)
+        if not hasattr(lhs, "domain"):
+            te = f"LHS of type {type(lhs)} does not have a `domain` property"
+            raise TypeError(te)
+        if rhs.domain != lhs.domain or ts_rhs != ts_lhs:
+            raise ValueError("domains and/or structures are incompatible")
+    return flags
+
+
+def _broadcast_binary_op(op, lhs, rhs):
+    from itertools import repeat
+
+    flags = _enforce_flags(lhs, rhs)
+
+    ts_lhs = tree_structure(lhs)
+    ts_rhs = tree_structure(rhs)
+    # Catch non-object scalars and 0d array-likes with an `ndim` property
+    if jnp.isscalar(lhs) or getattr(lhs, "ndim", -1) == 0:
+        lhs = ts_rhs.unflatten(repeat(lhs, ts_rhs.num_leaves))
+    elif jnp.isscalar(rhs) or getattr(rhs, "ndim", -1) == 0:
+        rhs = ts_lhs.unflatten(repeat(rhs, ts_lhs.num_leaves))
+    elif ts_lhs.num_nodes != ts_rhs.num_nodes:
+        ve = f"invalid binary operation {op} for {ts_lhs!r} and {ts_rhs!r}"
+        raise ValueError(ve)
+
+    out = tree_map(op, lhs, rhs)
+    if flags != set():
+        out._flags = flags
+    return out
+
+
+def _binary_op(op, name=None):
+    def binary_call(lhs, rhs):
+        return _broadcast_binary_op(op, lhs, rhs)
+
+    name = op.__name__ if name is None else name
+    binary_call.__name__ = f"__{name}__"
+    return binary_call
+
+
+def _rev_binary_op(op, name=None):
+    def binary_call(lhs, rhs):
+        return _broadcast_binary_op(op, rhs, lhs)
+
+    name = op.__name__ if name is None else name
+    binary_call.__name__ = f"__r{name}__"
+    return binary_call
+
+
+def _fwd_rev_binary_op(op, name=None):
+    return (_binary_op(op, name=name), _rev_binary_op(op, name=name))
+
+
+def matmul(lhs, rhs):
+    """Returns the dot product of the two fields.
+
+    Parameters
+    ----------
+    lhs : object
+        Arbitrary, flatten-able object.
+    rhs : object
+        Arbitrary, flatten-able object.
+
+    Returns
+    -------
+    out : float
+        Dot product of fields.
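+
+    Examples
+    --------
+    A small sketch on dict-like trees; every leaf is raveled and contracted
+    and the tree of partial results is summed up:
+
+    >>> from jax import numpy as jnp
+    >>> a = {"x": jnp.ones((2, 2)), "y": jnp.arange(3.)}
+    >>> b = {"x": jnp.full((2, 2), 2.), "y": jnp.ones(3)}
+    >>> float(matmul(a, b))  # 2. * 4 + (0. + 1. + 2.)
+    11.0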
+    """
+    from .forest_util import dot
+
+    _enforce_flags(lhs, rhs)
+
+    ts_lhs = tree_structure(lhs)
+    ts_rhs = tree_structure(rhs)
+    if ts_lhs.num_nodes != ts_rhs.num_nodes:
+        ve = f"invalid operation for {ts_lhs!r} and {ts_rhs!r}"
+        raise ValueError(ve)
+
+    return dot(lhs, rhs)
+
+
+dot = matmul
+
+
+@register_pytree_node_class
+class Field():
+    """Value storage for arbitrary objects with added numerics."""
+    supported_flags = {"strict_domain_checking"}
+
+    def __init__(self, val, domain=None, flags=None):
+        """Instantiates a field.
+
+        Parameters
+        ----------
+        val : object
+            Arbitrary, flatten-able objects.
+        domain : dict or None, optional
+            Domain of the field, e.g. with description of modes and volume.
+        flags : set, str or None, optional
+            Capabilities and constraints of the field.
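+
+        Examples
+        --------
+        A small sketch of the pytree arithmetics with made-up values:
+
+        >>> from jax import numpy as jnp
+        >>> f = Field({"x": jnp.ones(3)})
+        >>> g = 2. * f + f  # leaf-wise, i.e. g.val["x"] == jnp.full(3, 3.)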
+        """
+        self._val = val
+        self._domain = {} if domain is None else dict(domain)
+
+        flags = (flags, ) if isinstance(flags, str) else flags
+        flags = set() if flags is None else set(flags)
+        if not flags.issubset(Field.supported_flags):
+            ve = (
+                f"specified flags ({flags!r}) are not a subset of the"
+                f" supported flags ({Field.supported_flags!r})"
+            )
+            raise ValueError(ve)
+        self._flags = flags
+
+    def tree_flatten(self):
+        """Recipe for flattening fields.
+
+        Returns
+        -------
+        flat_tree : tuple of two tuples
+            Pair of an iterable with the children to be flattened recursively,
+            and some opaque auxiliary data.
+        """
+        return ((self._val, ), (self._domain, self._flags))
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        """Recipe to construct fields from flattened Pytrees.
+
+        Parameters
+        ----------
+        aux_data : tuple of a dict and a set
+            Opaque auxiliary data describing a field.
+        children: tuple
+            Value of the field, i.e. unflattened children.
+
+        Returns
+        -------
+        unflattened_tree : :class:`nifty8.re.field.Field`
+            Re-constructed field.
+        """
+        return cls(*children, domain=aux_data[0], flags=aux_data[1])
+
+    @property
+    def val(self):
+        """Retrieves a **view** of the field's values."""
+        return self._val
+
+    @property
+    def domain(self):
+        """Retrieves a **copy** of the field's domain."""
+        return self._domain.copy()
+
+    @property
+    def flags(self):
+        """Retrieves a **copy** of the field's flags."""
+        return self._flags.copy()
+
+    @property
+    def size(self):
+        from .forest_util import size
+
+        return size(self)
+
+    def __str__(self):
+        s = f"Field(\n{self.val}"
+        if self._domain:
+            s += f",\ndomain={self._domain}"
+        if self._flags:
+            s += f",\nflags={self._flags}"
+        s += ")"
+        return s
+
+    def __repr__(self):
+        s = f"Field(\n{self.val!r}"
+        if self._domain:
+            s += f",\ndomain={self._domain!r}"
+        if self._flags:
+            s += f",\nflags={self._flags!r}"
+        s += ")"
+        return s
+
+    def ravel(self):
+        return tree_map(jnp.ravel, self)
+
+    def __bool__(self):
+        return bool(self.val)
+
+    def __hash__(self):
+        return hash(tuple(tree_leaves(self)))
+
+    # NOTE, this partly redundant code could be abstracted away using
+    # `setattr`. However, static code analyzers will not be able to infer the
+    # properties then.
+
+    __add__, __radd__ = _fwd_rev_binary_op(operator.add)
+    __sub__, __rsub__ = _fwd_rev_binary_op(operator.sub)
+    __mul__, __rmul__ = _fwd_rev_binary_op(operator.mul)
+    __truediv__, __rtruediv__ = _fwd_rev_binary_op(operator.truediv)
+    __floordiv__, __rfloordiv__ = _fwd_rev_binary_op(operator.floordiv)
+    __pow__, __rpow__ = _fwd_rev_binary_op(operator.pow)
+    __mod__, __rmod__ = _fwd_rev_binary_op(operator.mod)
+    __matmul__ = __rmatmul__ = matmul  # arguments of matmul commute
+
+    def __divmod__(self, other):
+        return self // other, self % other
+
+    def __rdivmod__(self, other):
+        return other // self, other % self
+
+    __or__, __ror__ = _fwd_rev_binary_op(operator.or_, "or")
+    __xor__, __rxor__ = _fwd_rev_binary_op(operator.xor)
+    __and__, __rand__ = _fwd_rev_binary_op(operator.and_, "and")
+    __lshift__, __rlshift__ = _fwd_rev_binary_op(operator.lshift)
+    __rshift__, __rrshift__ = _fwd_rev_binary_op(operator.rshift)
+
+    __lt__ = _binary_op(operator.lt)
+    __le__ = _binary_op(operator.le)
+    __eq__ = _binary_op(operator.eq)
+    __ne__ = _binary_op(operator.ne)
+    __ge__ = _binary_op(operator.ge)
+    __gt__ = _binary_op(operator.gt)
+
+    __neg__ = _unary_op(operator.neg)
+    __pos__ = _unary_op(operator.pos)
+    __abs__ = _unary_op(operator.abs)
+    __invert__ = _unary_op(operator.invert)
+
+    conj = conjugate = _unary_op(jnp.conj)
+    real = _unary_op(jnp.real)
+    imag = _unary_op(jnp.imag)
+    dot = matmul
+
+    __getitem__ = _value_op(operator.getitem)
+    __contains__ = _value_op(operator.contains)
+    __len__ = _value_op(len)
+    __iter__ = _value_op(iter)
diff --git a/src/re/forest_util.py b/src/re/forest_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a493fb5a711e1126b1d6341d25c551166fb973c
--- /dev/null
+++ b/src/re/forest_util.py
@@ -0,0 +1,403 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial, reduce
+import operator
+from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
+
+from jax import lax
+from jax import numpy as jnp
+from jax.tree_util import (
+    all_leaves,
+    tree_leaves,
+    tree_map,
+    tree_reduce,
+    tree_structure,
+    tree_transpose,
+)
+import numpy as np
+
+from .field import Field
+from .sugar import is1d
+
+
+def split(mappable, keys):
+    """Split a dictionary into one containing only the specified keys and one
+    with all of the remaining ones.
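+
+    Examples
+    --------
+    A small sketch on a plain dict:
+
+    >>> split({"a": 1, "b": 2, "c": 3}, ("a", "c"))
+    ({'a': 1, 'c': 3}, {'b': 2})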
+    """
+    sel, rest = {}, {}
+    for k, v in mappable.items():
+        if k in keys:
+            sel[k] = v
+        else:
+            rest[k] = v
+    return sel, rest
+
+
+def unite(x, y, op=operator.add):
+    """Unites two array-, dict- or Field-like objects.
+
+    If a key is contained in both objects, then the fields at that key
+    are combined.
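+
+    Examples
+    --------
+    A small sketch on plain dicts (the key order of the result is not
+    guaranteed):
+
+    >>> out = unite({"a": 1, "b": 2}, {"b": 40, "c": 3})
+    >>> # out == {"a": 1, "b": 42, "c": 3} with the default `op=operator.add`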
+    """
+    if isinstance(x, Field) or isinstance(y, Field):
+        x = x.val if isinstance(x, Field) else x
+        y = y.val if isinstance(y, Field) else y
+        return Field(unite(x, y, op=op))
+    if not hasattr(x, "keys") and not hasattr(y, "keys"):
+        return op(x, y)
+    if not hasattr(x, "keys") or not hasattr(y, "keys"):
+        te = (
+            "one of the inputs does not have a `keys` property;"
+            f" got {type(x)} and {type(y)}"
+        )
+        raise TypeError(te)
+
+    out = {}
+    for k in x.keys() | y.keys():
+        if k in x and k in y:
+            out[k] = op(x[k], y[k])
+        elif k in x:
+            out[k] = x[k]
+        else:
+            out[k] = y[k]
+    return out
+
+
+CORE_ARITHMETIC_ATTRIBUTES = (
+    "__neg__", "__pos__", "__abs__", "__add__", "__radd__", "__sub__",
+    "__rsub__", "__mul__", "__rmul__", "__truediv__", "__rtruediv__",
+    "__floordiv__", "__rfloordiv__", "__pow__", "__rpow__", "__mod__",
+    "__rmod__", "__matmul__", "__rmatmul__"
+)
+
+
+def has_arithmetics(obj, additional_attributes=()):
+    desired_attrs = CORE_ARITHMETIC_ATTRIBUTES + additional_attributes
+    return all(hasattr(obj, attr) for attr in desired_attrs)
+
+
+def assert_arithmetics(obj, *args, **kwargs):
+    if not has_arithmetics(obj, *args, **kwargs):
+        ae = (
+            f"input of type {type(obj)} does not support"
+            " core arithmetic operations"
+            "\nmaybe you forgot to wrap your object in a"
+            " :class:`nifty8.re.field.Field` instance"
+        )
+        raise AssertionError(ae)
+
+
+class ShapeWithDtype():
+    """Minimal helper class storing the shape and dtype of an object.
+
+    Notes
+    -----
+    This class is not transparent to JAX, i.e. it shall not be flattened
+    itself. If used in a tree-like structure, it should only appear as a leaf.
+    """
+    def __init__(self, shape: Union[Tuple[int], List[int], int], dtype=None):
+        """Instantiates a storage unit for shape and dtype.
+
+        Parameters
+        ----------
+        shape : tuple or list of int
+            One-dimensional sequence of integers denoting the length of the
+            object along each of the object's axes.
+        dtype : dtype
+            Data-type of the to-be-described object.
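+
+        Examples
+        --------
+        A small sketch:
+
+        >>> swd = ShapeWithDtype((32, 32), dtype=float)
+        >>> (swd.shape, swd.size)
+        ((32, 32), 1024)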
+        """
+        if isinstance(shape, int):
+            shape = (shape, )
+        if isinstance(shape, list):
+            shape = tuple(shape)
+        if not is1d(shape):
+            ve = f"invalid shape; got {shape!r}"
+            raise TypeError(ve)
+
+        self._shape = shape
+        self._dtype = jnp.float64 if dtype is None else dtype
+        self._size = None
+
+    @classmethod
+    def from_leave(cls, element):
+        """Convenience method for creating an instance of `ShapeWithDtype` from
+        an object.
+
+        To map a whole tree-like structure to its shape and dtype, use JAX's
+        `tree_map` method like so:
+
+            tree_map(ShapeWithDtype.from_leave, tree)
+
+        Parameters
+        ----------
+        element : tree-like structure
+            Object from which to take the shape and data-type.
+
+        Returns
+        -------
+        swd : instance of ShapeWithDtype
+            Instance storing the shape and data-type of `element`.
+        """
+        if not all_leaves((element, )):
+            ve = "tree is not flat and still contains leaves"
+            raise ValueError(ve)
+        return cls(jnp.shape(element), get_dtype(element))
+
+    @property
+    def shape(self) -> Tuple[int]:
+        """Retrieves the shape."""
+        return self._shape
+
+    @property
+    def dtype(self):
+        """Retrieves the data-type."""
+        return self._dtype
+
+    @property
+    def size(self) -> int:
+        """Total number of elements."""
+        if self._size is None:
+            self._size = reduce(operator.mul, self.shape, 1)
+        return self._size
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
+    def __len__(self) -> int:
+        if self.ndim > 0:
+            return self.shape[0]
+        else:  # mimic numpy
+            raise TypeError("len() of unsized object")
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, ShapeWithDtype):
+            return False
+        else:
+            return (self.shape, self.dtype) == (other.shape, other.dtype)
+
+    def __repr__(self):
+        nm = self.__class__.__name__
+        return f"{nm}(shape={self.shape}, dtype={self.dtype})"
+
+
+def get_dtype(v: Any):
+    if hasattr(v, "dtype"):
+        return v.dtype
+    else:
+        return type(v)
+
+
+def common_type(*trees):
+    from numpy import find_common_type
+
+    common_dtp = find_common_type(
+        tuple(
+            find_common_type(tuple(get_dtype(v) for v in tree_leaves(tr)), ())
+            for tr in trees
+        ), ()
+    )
+    return common_dtp
+
+
+def _size(x):
+    return x.size if hasattr(x, "size") else jnp.size(x)
+
+
+def size(tree, axis: Optional[int] = None) -> int:
+    if axis is not None:
+        raise TypeError("axis of an arbitrary tree is ill defined")
+    sizes = tree_map(_size, tree)
+    return tree_reduce(operator.add, sizes)
+
+
+def _shape(x):
+    return x.shape if hasattr(x, "shape") else jnp.shape(x)
+
+
+T = TypeVar("T")
+
+
+def shape(tree: T) -> T:
+    return tree_map(_shape, tree)
+
+
+def _zeros_like(x, dtype, shape):
+    if hasattr(x, "shape") and hasattr(x, "dtype"):
+        shp = x.shape if shape is None else shape
+        dtp = x.dtype if dtype is None else dtype
+        return jnp.zeros(shape=shp, dtype=dtp)
+    return jnp.zeros_like(x, dtype=dtype, shape=shape)
+
+
+def zeros_like(a, dtype=None, shape=None):
+    return tree_map(partial(_zeros_like, dtype=dtype, shape=shape), a)
+
+
+def norm(tree, ord, *, ravel: bool):
+    from jax.numpy.linalg import norm
+
+    if ravel:
+
+        def el_norm(x):
+            return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x.ravel(), ord=ord)
+    else:
+
+        def el_norm(x):
+            return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x, ord=ord)
+
+    return norm(tree_leaves(tree_map(el_norm, tree)), ord=ord)
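+
+
+# NOTE, a small sketch for `norm`: every leaf is first reduced to a scalar
+# norm, then the norm over the tree of scalars is taken, e.g.
+# `norm({"a": 3., "b": 4.}, 2, ravel=True)` evaluates to
+# `jnp.sqrt(3.**2 + 4.**2) == 5.`.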
+
+
+def _ravel(x):
+    return x.ravel() if hasattr(x, "ravel") else jnp.ravel(x)
+
+
+def dot(a, b, *, precision=None):
+    tree_of_dots = tree_map(
+        lambda x, y: jnp.dot(_ravel(x), _ravel(y), precision=precision), a, b
+    )
+    return tree_reduce(operator.add, tree_of_dots, 0.)
+
+
+def vdot(a, b, *, precision=None):
+    tree_of_vdots = tree_map(
+        lambda x, y: jnp.vdot(_ravel(x), _ravel(y), precision=precision), a, b
+    )
+    return tree_reduce(jnp.add, tree_of_vdots, 0.)
+
+
+def select(pred, on_true, on_false):
+    return tree_map(partial(lax.select, pred), on_true, on_false)
+
+
+def where(condition, x, y):
+    """Selects a pytree based on the condition which can be a pytree itself.
+
+    Notes
+    -----
+    If `condition` is not a pytree, then a partially evaluated selection is
+    simply mapped over `x` and `y` without actually broadcasting `condition`.
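+
+    Examples
+    --------
+    A small sketch with a scalar (non-pytree) condition:
+
+    >>> from jax import numpy as jnp
+    >>> x = {"a": jnp.zeros(2)}
+    >>> y = {"a": jnp.ones(2)}
+    >>> res = where(jnp.array(True), x, y)  # == x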
+    """
+    from itertools import repeat
+
+    ts_c = tree_structure(condition)
+    ts_x = tree_structure(x)
+    ts_y = tree_structure(y)
+    ts_max = (ts_c, ts_x, ts_y)[np.argmax(
+        [ts_c.num_nodes, ts_x.num_nodes, ts_y.num_nodes]
+    )]
+
+    if ts_x.num_nodes < ts_max.num_nodes:
+        if ts_x.num_nodes > 1:
+            raise ValueError("can not broadcast LHS")
+        x = ts_max.unflatten(repeat(x, ts_max.num_leaves))
+    if ts_y.num_nodes < ts_max.num_nodes:
+        if ts_y.num_nodes > 1:
+            raise ValueError("can not broadcast RHS")
+        y = ts_max.unflatten(repeat(y, ts_max.num_leaves))
+
+    if ts_c.num_nodes < ts_max.num_nodes:
+        if ts_c.num_nodes > 1:
+            raise ValueError("can not map condition")
+        return tree_map(partial(jnp.where, condition), x, y)
+    return tree_map(jnp.where, condition, x, y)
+
+
+def stack(arrays):
+    return tree_map(lambda *el: jnp.stack(el), *arrays)
+
+
+def unstack(stack):
+    element_count = tree_leaves(stack)[0].shape[0]
+    split = partial(jnp.split, indices_or_sections=element_count)
+    unstacked = tree_transpose(
+        tree_structure(stack), tree_structure((0., ) * element_count),
+        tree_map(split, stack)
+    )
+    return tree_map(partial(jnp.squeeze, axis=0), unstacked)
+
+
+def map_forest(
+    f: Callable,
+    in_axes: Union[int, Tuple] = 0,
+    out_axes: Union[int, Tuple] = 0,
+    tree_transpose_output: bool = True,
+    mapping: Union[str, Callable] = 'vmap',
+    **kwargs
+) -> Callable:
+    from jax import vmap, pmap
+
+    if out_axes != 0:
+        raise TypeError("`out_axis` not yet supported")
+    in_axes = in_axes if isinstance(in_axes, tuple) else (in_axes, )
+    i = None
+    for idx, el in enumerate(in_axes):
+        if el is not None and i is None:
+            i = idx
+        elif el is not None and i is not None:
+            ve = "mapping over more than one axis is not yet supported"
+            raise ValueError(ve)
+    if i is None:
+        raise ValueError("must map over at least one axis")
+    if not isinstance(i, int):
+        te = "mapping over a non integer axis is not yet supported"
+        raise TypeError(te)
+
+    if isinstance(mapping, str):
+        if mapping == 'vmap' or mapping == 'v':
+            f_map = vmap(f, in_axes=in_axes, out_axes=out_axes, **kwargs)
+        elif mapping == 'pmap' or mapping == 'p':
+            f_map = pmap(f, in_axes=in_axes, out_axes=out_axes, **kwargs)
+        elif mapping == 'lax.map' or mapping == 'lax':
+            if all(el == 0
+                   for el in in_axes) and np.all(0 == np.array(out_axes)):
+                f_map = partial(lax.map, f)
+            else:
+                ve = (
+                    "mapping `in_axes` and `out_axes` along another axis than"
+                    " the 0-axis is not possible for `lax.map`"
+                )
+                raise ValueError(ve)
+        else:
+            ve = (
+                f"{mapping} is not an accepted key to a mapping function"
+                "; please pass function directly"
+            )
+            raise ValueError(ve)
+    elif callable(mapping):
+        f_map = mapping(f, in_axes=in_axes, out_axes=out_axes, **kwargs)
+    else:
+        te = (
+            f"invalid `mapping` of type {type(mapping)!r}"
+            "; expected string or callable"
+        )
+        raise TypeError(te)
+
+    def apply(*xs):
+        if not isinstance(xs[i], (list, tuple)):
+            te = f"expected mapped axes to be a tuple; got {type(xs[i])}"
+            raise TypeError(te)
+        x_T = stack(xs[i])
+
+        out_T = f_map(*xs[:i], x_T, *xs[i + 1:])
+        # Since `out_axes` is forced to be `0`, we don't need to worry about
+        # transposing only part of the output
+        if not tree_transpose_output:
+            return out_T
+        return unstack(out_T)
+
+    return apply
+
+
+def map_forest_mean(method, mapping='vmap', *args, **kwargs) -> Callable:
+    method_map = map_forest(
+        method, *args, tree_transpose_output=False, mapping=mapping, **kwargs
+    )
+
+    def meaned_apply(*xs, **xs_kw):
+        return tree_map(partial(jnp.mean, axis=0), method_map(*xs, **xs_kw))
+
+    return meaned_apply
diff --git a/src/re/hmc.py b/src/re/hmc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3bf05ae5def3afc17fd1bab9480857f89922d8b
--- /dev/null
+++ b/src/re/hmc.py
@@ -0,0 +1,630 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from typing import Callable, NamedTuple, TypeVar, Union
+
+from jax import numpy as jnp
+from jax import random, tree_util
+from jax.experimental import host_callback
+from jax.lax import population_count
+from jax.scipy.special import expit
+
+from .disable_jax_control_flow import cond, fori_loop, while_loop
+from .forest_util import select
+from .sugar import random_like
+
+_DEBUG_FLAG = False
+
+_DEBUG_TREE_END_IDXS = []
+_DEBUG_SUBTREE_END_IDXS = []
+_DEBUG_STORE = []
+
+
+def _DEBUG_ADD_QP(qp):
+    """Stores **all** results of leapfrog integration"""
+    global _DEBUG_STORE
+    _DEBUG_STORE.append(qp)
+
+
+def _DEBUG_FINISH_TREE(dummy_arg):
+    """Signal the position of a finished tree in `_DEBUG_STORE`"""
+    global _DEBUG_TREE_END_IDXS
+    _DEBUG_TREE_END_IDXS.append(len(_DEBUG_STORE))
+
+
+def _DEBUG_FINISH_SUBTREE(dummy_arg):
+    """Signal the position of a finished sub-tree in `_DEBUG_STORE`"""
+    global _DEBUG_SUBTREE_END_IDXS
+    _DEBUG_SUBTREE_END_IDXS.append(len(_DEBUG_STORE))
+
+
+### COMMON FUNCTIONALITY
+Q = TypeVar("Q")
+
+
+class QP(NamedTuple):
+    """Object holding a pair of position and momentum.
+
+    Attributes
+    ----------
+    position : Q
+        Position.
+    momentum : Q
+        Momentum.
+    """
+    position: Q
+    momentum: Q
+
+
+def flip_momentum(qp: QP) -> QP:
+    return QP(position=qp.position, momentum=-qp.momentum)
+
+
+def sample_momentum_from_diagonal(*, key, mass_matrix_sqrt):
+    """
+    Draw a momentum sample from the kinetic energy of the hamiltonian.
+
+    Parameters
+    ----------
+    key: ndarray
+        PRNGKey used as the random key.
+    mass_matrix_sqrt: ndarray
+        The left square-root mass matrix (i.e. square-root of the inverse
+        diagonal covariance) to use for sampling. Diagonal matrix represented
+        as (possibly pytree of) ndarray vector containing the entries of the
+        diagonal.
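+
+    Examples
+    --------
+    A small sketch; for a unit mass matrix the draw reduces to a
+    standard-normal sample of the same shape:
+
+    >>> from jax import numpy as jnp, random
+    >>> key = random.PRNGKey(42)
+    >>> p = sample_momentum_from_diagonal(key=key, mass_matrix_sqrt=jnp.ones(3))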
+    """
+    normal = random_like(key=key, primals=mass_matrix_sqrt, rng=random.normal)
+    return tree_util.tree_map(jnp.multiply, mass_matrix_sqrt, normal)
+
+
+# TODO: how to randomize step size (neal sect. 3.2)
+# @partial(jit, static_argnames=('potential_energy_gradient',))
+def leapfrog_step(
+    potential_energy_gradient,
+    kinetic_energy_gradient,
+    step_size,
+    inverse_mass_matrix,
+    qp: QP,
+):
+    """
+    Perform one iteration of the leapfrog integrator forwards in time.
+
+    Parameters
+    ----------
+    potential_energy_gradient: Callable[[ndarray], float]
+        Gradient of the potential energy part of the Hamiltonian (V). Depends
+        on the position only.
+    kinetic_energy_gradient: Callable[[ndarray, ndarray], ndarray]
+        Gradient of the kinetic energy w.r.t. the momentum. Takes the inverse
+        mass matrix and the momentum as arguments.
+    step_size: float
+        Step length (usually called epsilon) of the leapfrog integrator.
+    inverse_mass_matrix: ndarray
+        Inverse (diagonal) mass matrix entering the kinetic energy.
+    qp: QP
+        Point in position and momentum space from which to start integration.
+    """
+    position = qp.position
+    momentum = qp.momentum
+
+    momentum_halfstep = (
+        momentum - (step_size / 2.) * potential_energy_gradient(position)
+    )
+
+    position_fullstep = position + step_size * kinetic_energy_gradient(
+        inverse_mass_matrix, momentum_halfstep
+    )
+
+    momentum_fullstep = (
+        momentum_halfstep -
+        (step_size / 2.) * potential_energy_gradient(position_fullstep)
+    )
+
+    qp_fullstep = QP(position=position_fullstep, momentum=momentum_fullstep)
+
+    global _DEBUG_FLAG
+    if _DEBUG_FLAG:
+        # append result to global list variable
+        host_callback.call(_DEBUG_ADD_QP, qp_fullstep)
+
+    return qp_fullstep
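+
+
+# NOTE, a one-step usage sketch for a 1D harmonic oscillator with potential
+# V(q) = q**2 / 2 and a unit mass matrix; purely illustrative:
+#
+#   step = partial(leapfrog_step, lambda q: q, lambda inv_m, p: inv_m * p)
+#   qp_next = step(0.1, 1., QP(position=jnp.ones(()), momentum=jnp.zeros(())))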
+
+
+### SIMPLE HMC
+class AcceptedAndRejected(NamedTuple):
+    accepted_qp: QP
+    rejected_qp: QP
+    accepted: Union[jnp.ndarray, bool]
+    diverging: Union[jnp.ndarray, bool]
+
+
+# @partial(jit, static_argnames=('potential_energy', 'potential_energy_gradient'))
+def generate_hmc_acc_rej(
+    *, key, initial_qp, potential_energy, kinetic_energy, inverse_mass_matrix,
+    stepper, num_steps, step_size, max_energy_difference
+) -> AcceptedAndRejected:
+    """
+    Generate a sample given the initial position.
+
+    Parameters
+    ----------
+    key: ndarray
+        PRNGKey used as the random key.
+    initial_qp: QP
+        Starting pair of (position, momentum) for this step of the Markov
+        chain.
+    potential_energy: Callable[[ndarray], float]
+        The potential energy, which is the distribution to be sampled from.
+    kinetic_energy: Callable[[Q, Q], float]
+        Mapping of the inverse mass matrix and the momentum to the kinetic
+        energy.
+    inverse_mass_matrix: ndarray
+        The inverse mass matrix used in the kinetic energy.
+    stepper: Callable[[float, Q, QP], QP]
+        The function that performs (leapfrog) steps.
+    num_steps: int
+        The number of steps the leapfrog integrator should perform.
+    step_size: float
+        The step size (usually epsilon) for the leapfrog integrator.
+    max_energy_difference: float
+        Energy difference above which the sample is flagged as diverging.
+    """
+    loop_body = partial(stepper, step_size, inverse_mass_matrix)
+    new_qp = fori_loop(
+        lower=0,
+        upper=num_steps,
+        body_fun=lambda _, args: loop_body(args),
+        init_val=initial_qp
+    )
+    # Flipping the momentum is needed to make the proposal distribution
+    # symmetric. It has no effect on the acceptance probability here because
+    # the kinetic energy depends on momentum^2, but it might matter for other
+    # kinetic energies.
+    proposed_qp = flip_momentum(new_qp)
+
+    total_energy = partial(
+        total_energy_of_qp,
+        potential_energy=potential_energy,
+        kinetic_energy_w_inv_mass=partial(kinetic_energy, inverse_mass_matrix)
+    )
+    energy_diff = total_energy(initial_qp) - total_energy(proposed_qp)
+    energy_diff = jnp.where(jnp.isnan(energy_diff), jnp.inf, energy_diff)
+    transition_probability = jnp.minimum(1., jnp.exp(energy_diff))
+
+    accept = random.bernoulli(key, transition_probability)
+    accepted_qp, rejected_qp = select(
+        accept,
+        (proposed_qp, initial_qp),
+        (initial_qp, proposed_qp),
+    )
+    diverging = jnp.abs(energy_diff) > max_energy_difference
+    return AcceptedAndRejected(
+        accepted_qp, rejected_qp, accepted=accept, diverging=diverging
+    )
+
+
+### NUTS
+class Tree(NamedTuple):
+    """Object carrying tree metadata.
+
+    Attributes
+    ----------
+    left, right : QP
+        Respective endpoints of the tree's path.
+    logweight: Union[jnp.ndarray, float]
+        Sum over all -H(q, p) in the tree's path.
+    proposal_candidate: QP
+        Sample from the tree's path, distributed as exp(-H(q, p)).
+    turning: Union[jnp.ndarray, bool]
+        Indicator for whether the left or right endpoint is a U-turn or
+        whether any subtree is a U-turn.
+    diverging: Union[jnp.ndarray, bool]
+        Indicator for a large increase in energy in the next larger tree.
+    depth: Union[jnp.ndarray, int]
+        Levels of the tree.
+    cumulative_acceptance: Union[jnp.ndarray, float]
+        Sum of all acceptance probabilities relative to some initial energy
+        value. This value is distinct from `logweight` as its absolute value is
+        only well defined for the very final tree of NUTS.
+    """
+    left: QP
+    right: QP
+    logweight: Union[jnp.ndarray, float]
+    proposal_candidate: QP
+    turning: Union[jnp.ndarray, bool]
+    diverging: Union[jnp.ndarray, bool]
+    depth: Union[jnp.ndarray, int]
+    cumulative_acceptance: Union[jnp.ndarray, float]
+
+
+def total_energy_of_qp(qp, potential_energy, kinetic_energy_w_inv_mass):
+    return potential_energy(qp.position
+                           ) + kinetic_energy_w_inv_mass(qp.momentum)
+
+
+def generate_nuts_tree(
+    initial_qp,
+    key,
+    step_size,
+    max_tree_depth,
+    stepper: Callable[[Union[jnp.ndarray, float], Q, QP], QP],
+    potential_energy,
+    kinetic_energy: Callable[[Q, Q], float],
+    inverse_mass_matrix: Q,
+    bias_transition: bool = True,
+    max_energy_difference: Union[jnp.ndarray, float] = jnp.inf
+) -> Tree:
+    """Generate a sample given the initial position.
+
+    This call implements a No-U-Turn-Sampler.
+
+    Parameters
+    ----------
+    initial_qp: QP
+        Starting pair of (position, momentum). **NOTE**, the momentum must be
+        resampled from the conditional distribution **BEFORE** passing it
+        into this function!
+    key: ndarray
+        PRNGKey used as the random key.
+    step_size: float
+        Step size (usually called epsilon) for the leapfrog integrator.
+    max_tree_depth: int
+        The maximum depth of the trajectory tree before the expansion is
+        terminated. At the maximum iteration depth, the current value is
+        returned even if the U-turn condition is not met. The maximum number of
+        points (/integration steps) per trajectory is :math:`N =
+        2^{\\mathrm{max\\_tree\\_depth}}`. This function requires memory linear
+        in max_tree_depth, i.e. logarithmic in trajectory length. It is used to
+        statically allocate memory in advance.
+    stepper: Callable[[float, Q, QP], QP]
+        The function that performs (Leapfrog) steps. Takes as arguments (in order)
+        1) step size (containing the direction): float ,
+        2) inverse mass matrix: Q ,
+        3) starting point: QP .
+    potential_energy: Callable[[Q], float]
+        The potential energy, of the distribution to be sampled from. Takes
+        only the position part (QP.position) as argument.
+    kinetic_energy: Callable[[Q, Q], float], optional
+        Mapping of the momentum to its corresponding kinetic energy. As
+        argument the function takes the inverse mass matrix and the momentum.
+
+    Returns
+    -------
+    current_tree: Tree
+        The final tree, carrying a sample from the target distribution.
+
+    See Also
+    --------
+    No-U-Turn Sampler original paper (2011): https://arxiv.org/abs/1111.4246
+    NumPyro Iterative NUTS paper: https://arxiv.org/abs/1912.11554
+    Combining samples from two trees and sampling from trajectories according
+    to the target distribution, appendix of: https://arxiv.org/abs/1701.02434
+    """
+    # initialize depth 0 tree, containing 2**0 = 1 points
+    initial_neg_energy = -total_energy_of_qp(
+        initial_qp, potential_energy,
+        partial(kinetic_energy, inverse_mass_matrix)
+    )
+    current_tree = Tree(
+        left=initial_qp,
+        right=initial_qp,
+        logweight=initial_neg_energy,
+        proposal_candidate=initial_qp,
+        turning=False,
+        diverging=False,
+        depth=0,
+        cumulative_acceptance=jnp.zeros_like(initial_neg_energy)
+    )
+
+    def _cont_cond(loop_state):
+        _, current_tree, stop = loop_state
+        return (~stop) & (current_tree.depth <= max_tree_depth)
+
+    def cond_tree_doubling(loop_state):
+        key, current_tree, _ = loop_state
+        key, key_dir, key_subtree, key_merge = random.split(key, 4)
+
+        go_right = random.bernoulli(key_dir, 0.5)
+
+        # build tree adjacent to current_tree
+        new_subtree = iterative_build_tree(
+            key_subtree,
+            current_tree,
+            step_size,
+            go_right,
+            stepper,
+            potential_energy,
+            kinetic_energy,
+            inverse_mass_matrix,
+            max_tree_depth,
+            initial_neg_energy=initial_neg_energy,
+            max_energy_difference=max_energy_difference
+        )
+        # Mark current tree as diverging if it diverges in the next step
+        current_tree = current_tree._replace(diverging=new_subtree.diverging)
+
+        # Combine current_tree and new_subtree into a tree which is one
+        # layer deeper, only if new_subtree has no turning subtrees
+        # (including itself)
+        current_tree = cond(
+            # If new tree is turning or diverging, do not merge
+            pred=new_subtree.turning | new_subtree.diverging,
+            true_fun=lambda old_and_new: old_and_new[0],
+            false_fun=lambda old_and_new: merge_trees(
+                key_merge,
+                old_and_new[0],
+                old_and_new[1],
+                go_right,
+                bias_transition=bias_transition
+            ),
+            operand=(current_tree, new_subtree),
+        )
+        # Stop if the new subtree was turning; we then sample from the old
+        # one and don't expand further. Likewise stop if the new total tree
+        # is turning; we then sample from the combined trajectory and don't
+        # expand further.
+        stop = new_subtree.turning | current_tree.turning
+        stop |= new_subtree.diverging
+        return (key, current_tree, stop)
+
+    loop_state = (key, current_tree, False)
+    _, current_tree, _ = while_loop(_cont_cond, cond_tree_doubling, loop_state)
+
+    global _DEBUG_FLAG
+    if _DEBUG_FLAG:
+        host_callback.call(_DEBUG_FINISH_TREE, None)
+
+    return current_tree
+
+
+def tree_index_get(ptree, idx):
+    return tree_util.tree_map(lambda arr: arr[idx], ptree)
+
+
+def tree_index_update(x, idx, y):
+    from jax.tree_util import tree_map
+
+    return tree_map(lambda x_el, y_el: x_el.at[idx].set(y_el), x, y)
+
+
+def count_trailing_ones(n):
+    """Count the number of trailing, consecutive ones in the binary
+    representation of `n`.
+
+    Warning
+    -------
+    `n` must be positive and strictly smaller than 2**64
+
+    Examples
+    --------
+    >>> print(bin(23), count_trailing_ones(23))
+    0b10111 3
+    """
+    # taken from http://num.pyro.ai/en/stable/_modules/numpyro/infer/hmc_util.html
+    _, trailing_ones_count = while_loop(
+        lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
+    )
+    return trailing_ones_count
+
+
+def is_euclidean_uturn(qp_left, qp_right):
+    """
+    See Also
+    --------
+    Betancourt - A conceptual introduction to Hamiltonian Monte Carlo
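+
+    Examples
+    --------
+    A 1D sketch; two points moving away from each other do not constitute a
+    U-turn while two points moving towards each other do:
+
+    >>> from jax import numpy as jnp
+    >>> le = QP(position=jnp.array([0.]), momentum=jnp.array([-1.]))
+    >>> ri = QP(position=jnp.array([1.]), momentum=jnp.array([1.]))
+    >>> bool(is_euclidean_uturn(le, ri))
+    False
+    >>> bool(is_euclidean_uturn(flip_momentum(le), flip_momentum(ri)))
+    True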
+    """
+    return (
+        (qp_right.momentum.dot(qp_right.position - qp_left.position) < 0.) &
+        (qp_left.momentum.dot(qp_left.position - qp_right.position) < 0.)
+    )
+
+
+# Essentially algorithm 2 from https://arxiv.org/pdf/1912.11554.pdf
+def iterative_build_tree(
+    key, initial_tree, step_size, go_right, stepper, potential_energy,
+    kinetic_energy, inverse_mass_matrix, max_tree_depth, initial_neg_energy,
+    max_energy_difference
+):
+    """
+    Starting from either the left or right endpoint of a given tree, builds a
+    new adjacent tree of the same size.
+
+    Parameters
+    ----------
+    key: ndarray
+        PRNGKey to choose a sample when adding QPs to the tree.
+    initial_tree: Tree
+        Tree to be extended (doubled) on the left or right.
+    step_size: float
+        The step size (usually called epsilon) for the leapfrog integrator.
+    go_right: bool
+        If `go_right` start at the right end, going right else start at the
+        left end, going left.
+    stepper: Callable[[float, Q, QP], QP]
+        The function that performs (Leapfrog) steps. Takes as arguments (in order)
+        1) step size (containing the direction): float ,
+        2) inverse mass matrix: Q ,
+        3) starting point: QP .
+    potential_energy: Callable[[Q], float]
+        Potential energy, of the distribution to be sampled from. Takes
+        only the position part (QP.position) as argument.
+    kinetic_energy: Callable[[Q, Q], float], optional
+        Mapping of the momentum to its corresponding kinetic energy. As
+        argument the function takes the inverse mass matrix and the momentum.
+    max_tree_depth: int
+        An upper bound on the depth of the tree to be built; it has no effect
+        on the function's behaviour. It is only required to statically set
+        the size of the `S` array (Q).
+    """
+    # 1. choose start point of integration
+    z = select(go_right, initial_tree.right, initial_tree.left)
+    depth = initial_tree.depth
+    max_num_proposals = 2**depth
+    # 2. build / collect new states
+    # Create a storage for left endpoints of subtrees. Size is determined
+    # statically by the `max_tree_depth` parameter.
+    # NOTE, let's hope this does not break anything but in principle we only
+    # need `max_tree_depth` elements even though the tree can be of depth
+    # `max_tree_depth + 1`. This is because we will never access the last
+    # element.
+    S = tree_util.tree_map(
+        lambda proto: jnp.
+        empty_like(proto, shape=(max_tree_depth, ) + jnp.shape(proto)), z
+    )
+
+    z = stepper(
+        jnp.where(go_right, 1., -1.) * step_size, inverse_mass_matrix, z
+    )
+    neg_energy = -total_energy_of_qp(
+        z, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
+    )
+    diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
+    cum_acceptance = jnp.minimum(1., jnp.exp(initial_neg_energy - neg_energy))
+    incomplete_tree = Tree(
+        left=z,
+        right=z,
+        logweight=neg_energy,
+        proposal_candidate=z,
+        turning=False,
+        diverging=diverging,
+        depth=-1,
+        cumulative_acceptance=cum_acceptance
+    )
+    S = tree_index_update(S, 0, z)
+
+    def amend_incomplete_tree(state):
+        n, incomplete_tree, z, S, key = state
+
+        key, key_choose_candidate = random.split(key)
+        z = stepper(
+            jnp.where(go_right, 1., -1.) * step_size, inverse_mass_matrix, z
+        )
+        incomplete_tree = add_single_qp_to_tree(
+            key_choose_candidate,
+            incomplete_tree,
+            z,
+            go_right,
+            potential_energy,
+            kinetic_energy,
+            inverse_mass_matrix,
+            initial_neg_energy=initial_neg_energy,
+            max_energy_difference=max_energy_difference
+        )
+
+        def _even_fun(S):
+            # n is even, the current z is w.l.o.g. a left endpoint of some
+            # subtrees. Register the current z to be used in turning condition
+            # checks later, when the right endpoints of its subtrees are
+            # generated.
+            S = tree_index_update(S, population_count(n), z)
+            return S, False
+
+        def _odd_fun(S):
+            # n is odd, the current z is w.l.o.g a right endpoint of some
+            # subtrees. Check turning condition against all left endpoints of
+            # subtrees that have the current z (/n) as their right endpoint.
+
+            # l = number of subtrees that have the current z as their right
+            # endpoint.
+            l = count_trailing_ones(n)
+            # inclusive indices into S referring to the left endpoints of the l subtrees.
+            i_max_incl = population_count(n - 1)
+            i_min_incl = i_max_incl - l + 1
+            # TODO: this should traverse the range in reverse
+            turning = fori_loop(
+                lower=i_min_incl,
+                upper=i_max_incl + 1,
+                # TODO: conditional for early termination
+                body_fun=lambda k, turning: turning |
+                is_euclidean_uturn(tree_index_get(S, k), z),
+                init_val=False
+            )
+            return S, turning
+
+        S, turning = cond(
+            pred=n % 2 == 0, true_fun=_even_fun, false_fun=_odd_fun, operand=S
+        )
+        incomplete_tree = incomplete_tree._replace(turning=turning)
+        return (n + 1, incomplete_tree, z, S, key)
+
+    def _cont_cond(state):
+        n, incomplete_tree, *_ = state
+        return (n < max_num_proposals) & (~incomplete_tree.turning
+                                         ) & (~incomplete_tree.diverging)
+
+    n, incomplete_tree, *_ = while_loop(
+        # while n < 2**depth and not stop
+        cond_fun=_cont_cond,
+        body_fun=amend_incomplete_tree,
+        init_val=(1, incomplete_tree, z, S, key)
+    )
+
+    global _DEBUG_FLAG
+    if _DEBUG_FLAG:
+        host_callback.call(_DEBUG_FINISH_SUBTREE, None)
+
+    # The depth of a tree which was aborted early is possibly ill defined
+    depth = jnp.where(n == max_num_proposals, depth, -1)
+    return incomplete_tree._replace(depth=depth)
+
+
+def add_single_qp_to_tree(
+    key, tree, qp, go_right, potential_energy, kinetic_energy,
+    inverse_mass_matrix, initial_neg_energy, max_energy_difference
+):
+    """Helper function for progressive sampling. Takes a tree with a sample, and
+    a new endpoint, propagates sample.
+    """
+    # This is technically just a special case of merge_trees with one of the
+    # trees being a singleton, depth 0 tree. However, no turning check is
+    # required and it is not possible to bias the transition.
+    left, right = select(go_right, (tree.left, qp), (qp, tree.right))
+
+    neg_energy = -total_energy_of_qp(
+        qp, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
+    )
+    diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
+    # ln(e^-H_1 + e^-H_2)
+    total_logweight = jnp.logaddexp(tree.logweight, neg_energy)
+    # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
+    prob_of_keeping_old = expit(tree.logweight - neg_energy)
+    remain = random.bernoulli(key, prob_of_keeping_old)
+    proposal_candidate = select(remain, tree.proposal_candidate, qp)
+    # NOTE, set an invalid depth so as to indicate that adding a single QP
+    # to a perfect binary tree does not yield another perfect binary tree
+    cum_acceptance = tree.cumulative_acceptance + jnp.minimum(
+        1., jnp.exp(initial_neg_energy - neg_energy)
+    )
+    return Tree(
+        left,
+        right,
+        total_logweight,
+        proposal_candidate,
+        turning=tree.turning,
+        diverging=diverging,
+        depth=-1,
+        cumulative_acceptance=cum_acceptance
+    )
+
+
+def merge_trees(key, current_subtree, new_subtree, go_right, bias_transition):
+    """Merges two trees, propagating the proposal_candidate"""
+    # 5. decide which sample to take based on total weights (merge trees)
+    if bias_transition:
+        # Bias the transition towards the new subtree (see Betancourt
+        # conceptual intro (and Numpyro))
+        transition_probability = jnp.minimum(
+            1., jnp.exp(new_subtree.logweight - current_subtree.logweight)
+        )
+    else:
+        # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
+        transition_probability = expit(
+            new_subtree.logweight - current_subtree.logweight
+        )
+    # print(f"prob of choosing new sample: {transition_probability}")
+    new_sample = select(
+        random.bernoulli(key, transition_probability),
+        new_subtree.proposal_candidate, current_subtree.proposal_candidate
+    )
+    # 6. define new tree
+    left, right = select(
+        go_right,
+        (current_subtree.left, new_subtree.right),
+        (new_subtree.left, current_subtree.right),
+    )
+    turning = is_euclidean_uturn(left, right)
+    diverging = current_subtree.diverging | new_subtree.diverging
+    neg_energy = jnp.logaddexp(new_subtree.logweight, current_subtree.logweight)
+    cum_acceptance = current_subtree.cumulative_acceptance + new_subtree.cumulative_acceptance
+    merged_tree = Tree(
+        left=left,
+        right=right,
+        logweight=neg_energy,
+        proposal_candidate=new_sample,
+        turning=turning,
+        diverging=diverging,
+        depth=current_subtree.depth + 1,
+        cumulative_acceptance=cum_acceptance
+    )
+    return merged_tree
diff --git a/src/re/hmc_oo.py b/src/re/hmc_oo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c001065cc35ae03dbf545ddad1428c4d3cb3b28
--- /dev/null
+++ b/src/re/hmc_oo.py
@@ -0,0 +1,355 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+
+import numpy as np
+from jax import grad
+from jax import numpy as jnp
+from jax import random, tree_util
+
+from .disable_jax_control_flow import fori_loop
+from .hmc import AcceptedAndRejected, Q, QP, Tree
+from .hmc import (
+    generate_hmc_acc_rej,
+    generate_nuts_tree,
+    leapfrog_step,
+    sample_momentum_from_diagonal,
+    tree_index_update,
+)
+
+
+def _parse_diag_mass_matrix(mass_matrix, position_proto: Q) -> Q:
+    if isinstance(mass_matrix,
+                  (float, jnp.ndarray)) and jnp.size(mass_matrix) == 1:
+        mass_matrix = tree_util.tree_map(
+            partial(jnp.full_like, fill_value=mass_matrix), position_proto
+        )
+    elif tree_util.tree_structure(mass_matrix
+                                 ) == tree_util.tree_structure(position_proto):
+        shape_match_tree = tree_util.tree_map(
+            lambda a1, a2: jnp.shape(a1) == jnp.shape(a2), mass_matrix,
+            position_proto
+        )
+        shape_and_structure_match = all(
+            tree_util.tree_leaves(shape_match_tree)
+        )
+        if not shape_and_structure_match:
+            ve = "matrix has same tree_structe as the position but shapes do not match up"
+            raise ValueError(ve)
+    else:
+        te = "matrix must either be float or have same tree structure as the position"
+        raise TypeError(te)
+
+    return mass_matrix
+
+
+class Chain(NamedTuple):
+    """Object carrying chain metadata; think: transposed Tree with new axis.
+    """
+    # Q but with one more dimension on the first axis of the leaf tensors
+    samples: Q
+    divergences: jnp.ndarray
+    acceptance: Union[jnp.ndarray, float]
+    depths: Optional[jnp.ndarray] = None
+    trees: Optional[Union[Tree, AcceptedAndRejected]] = None
+
+
+class _Sampler:
+    def __init__(
+        self,
+        potential_energy: Callable[[Q], Union[jnp.ndarray, float]],
+        inverse_mass_matrix,
+        position_proto: Q,
+        step_size: Union[jnp.ndarray, float] = 1.0,
+        max_energy_difference: Union[jnp.ndarray, float] = jnp.inf
+    ):
+        if not callable(potential_energy):
+            raise TypeError("`potential_energy` must be callable")
+        if not isinstance(step_size, (jnp.ndarray, float)):
+            raise TypeError("`step_size` must be a float or ndarray")
+
+        self.potential_energy = potential_energy
+
+        self.inverse_mass_matrix = _parse_diag_mass_matrix(
+            inverse_mass_matrix, position_proto=position_proto
+        )
+        self.mass_matrix_sqrt = self.inverse_mass_matrix**(-0.5)
+
+        self.step_size = step_size
+
+        def kinetic_energy(inverse_mass_matrix, momentum):
+            # NOTE, assume a diagonal mass-matrix
+            return inverse_mass_matrix.dot(momentum**2) / 2.
+
+        self.kinetic_energy = kinetic_energy
+        kinetic_energy_gradient = lambda inv_m, mom: inv_m * mom
+        potential_energy_gradient = grad(self.potential_energy)
+        self.stepper = partial(
+            leapfrog_step, potential_energy_gradient, kinetic_energy_gradient
+        )
+
+        self.max_energy_difference = max_energy_difference
+
+        def sample_next_state(key,
+                              prev_position: Q) -> Tuple[Any, Tuple[Any, Q]]:
+            raise NotImplementedError()
+
+        self.sample_next_state = sample_next_state
+
+    @staticmethod
+    def init_chain(
+        num_samples: int, position_proto, save_intermediates: bool
+    ) -> Chain:
+        raise NotImplementedError()
+
+    @staticmethod
+    def update_chain(
+        chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree
+    ) -> Chain:
+        raise NotImplementedError()
+
+    def generate_n_samples(
+        self,
+        key: Any,
+        initial_position: Q,
+        num_samples,
+        *,
+        save_intermediates: bool = False
+    ) -> Tuple[Chain, Tuple[Any, Q]]:
+        if not isinstance(key, (jnp.ndarray, np.ndarray)):
+            if isinstance(key, int):
+                key = random.PRNGKey(key)
+            else:
+                raise TypeError("`key` must be a PRNGKey or an integer seed")
+
+        chain = self.init_chain(
+            num_samples, initial_position, save_intermediates
+        )
+
+        def amend_chain(idx, state):
+            chain, core_state = state
+            tree, core_state = self.sample_next_state(*core_state)
+            chain = self.update_chain(chain, idx, tree)
+            return chain, core_state
+
+        chain, core_state = fori_loop(
+            lower=0,
+            upper=num_samples,
+            body_fun=amend_chain,
+            init_val=(chain, (key, initial_position))
+        )
+
+        return chain, core_state
+
+
+class NUTSChain(_Sampler):
+    def __init__(
+        self,
+        potential_energy: Callable[[Q], Union[float, jnp.ndarray]],
+        inverse_mass_matrix,
+        position_proto: Q,
+        step_size: float = 1.0,
+        max_tree_depth: int = 10,
+        bias_transition: bool = True,
+        max_energy_difference: float = jnp.inf
+    ):
+        super().__init__(
+            potential_energy=potential_energy,
+            inverse_mass_matrix=inverse_mass_matrix,
+            position_proto=position_proto,
+            step_size=step_size,
+            max_energy_difference=max_energy_difference
+        )
+
+        if not isinstance(max_tree_depth, int):
+            raise TypeError("`max_tree_depth` must be an integer")
+        self.bias_transition = bias_transition
+        self.max_tree_depth = max_tree_depth
+
+        def sample_next_state(key,
+                              prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]:
+            key, key_momentum, key_nuts = random.split(key, 3)
+
+            resampled_momentum = sample_momentum_from_diagonal(
+                key=key_momentum, mass_matrix_sqrt=self.mass_matrix_sqrt
+            )
+            qp = QP(position=prev_position, momentum=resampled_momentum)
+
+            tree = generate_nuts_tree(
+                initial_qp=qp,
+                key=key_nuts,
+                step_size=self.step_size,
+                max_tree_depth=self.max_tree_depth,
+                stepper=self.stepper,
+                potential_energy=self.potential_energy,
+                kinetic_energy=self.kinetic_energy,
+                inverse_mass_matrix=self.inverse_mass_matrix,
+                bias_transition=self.bias_transition,
+                max_energy_difference=self.max_energy_difference
+            )
+            return tree, (key, tree.proposal_candidate.position)
+
+        self.sample_next_state = sample_next_state
+
+    @staticmethod
+    def init_chain(
+        num_samples: int, position_proto, save_intermediates: bool
+    ) -> Chain:
+        samples = tree_util.tree_map(
+            lambda arr: jnp.
+            zeros_like(arr, shape=(num_samples, ) + jnp.shape(arr)),
+            position_proto
+        )
+        depths = jnp.zeros(num_samples, dtype=jnp.uint64)
+        divergences = jnp.zeros(num_samples, dtype=bool)
+        chain = Chain(
+            samples=samples,
+            divergences=divergences,
+            acceptance=0.,
+            depths=depths
+        )
+        if save_intermediates:
+            _qp_proto = QP(position_proto, position_proto)
+            _tree_proto = Tree(
+                _qp_proto,
+                _qp_proto,
+                0.,
+                _qp_proto,
+                turning=True,
+                diverging=True,
+                depth=0,
+                cumulative_acceptance=0.
+            )
+            trees = tree_util.tree_map(
+                lambda leaf: jnp.
+                zeros_like(leaf, shape=(num_samples, ) + jnp.shape(leaf)),
+                _tree_proto
+            )
+            chain = chain._replace(trees=trees)
+
+        return chain
+
+    @staticmethod
+    def update_chain(
+        chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree
+    ) -> Chain:
+        num_proposals = 2**jnp.array(tree.depth, dtype=jnp.uint64) - 1
+        tree_acceptance = jnp.where(
+            num_proposals > 0, tree.cumulative_acceptance / num_proposals, 0.
+        )
+
+        samples = tree_index_update(
+            chain.samples, idx, tree.proposal_candidate.position
+        )
+        divergences = chain.divergences.at[idx].set(tree.diverging)
+        depths = chain.depths.at[idx].set(tree.depth)
+        acceptance = (
+            chain.acceptance + (tree_acceptance - chain.acceptance) / (idx + 1)
+        )
+        chain = chain._replace(
+            samples=samples,
+            divergences=divergences,
+            acceptance=acceptance,
+            depths=depths
+        )
+        if chain.trees is not None:
+            trees = tree_index_update(chain.trees, idx, tree)
+            chain = chain._replace(trees=trees)
+
+        return chain
+
+
+class HMCChain(_Sampler):
+    def __init__(
+        self,
+        potential_energy: Callable,
+        inverse_mass_matrix,
+        position_proto,
+        num_steps,
+        step_size: float = 1.0,
+        max_energy_difference: float = jnp.inf
+    ):
+        super().__init__(
+            potential_energy=potential_energy,
+            inverse_mass_matrix=inverse_mass_matrix,
+            position_proto=position_proto,
+            step_size=step_size,
+            max_energy_difference=max_energy_difference
+        )
+
+        if not isinstance(num_steps, (jnp.ndarray, int)):
+            te = f"`num_steps` must be an `int` or array; got {num_steps!r}"
+            raise TypeError(te)
+        self.num_steps = num_steps
+
+        def sample_next_state(key,
+                              prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]:
+            key, key_choose, key_momentum_resample = random.split(key, 3)
+
+            resampled_momentum = sample_momentum_from_diagonal(
+                key=key_momentum_resample,
+                mass_matrix_sqrt=self.mass_matrix_sqrt
+            )
+            qp = QP(position=prev_position, momentum=resampled_momentum)
+
+            acc_rej = generate_hmc_acc_rej(
+                key=key_choose,
+                initial_qp=qp,
+                potential_energy=self.potential_energy,
+                kinetic_energy=self.kinetic_energy,
+                inverse_mass_matrix=self.inverse_mass_matrix,
+                stepper=self.stepper,
+                num_steps=self.num_steps,
+                step_size=self.step_size,
+                max_energy_difference=self.max_energy_difference
+            )
+            return acc_rej, (key, acc_rej.accepted_qp.position)
+
+        self.sample_next_state = sample_next_state
+
+    @staticmethod
+    def init_chain(
+        num_samples: int, position_proto, save_intermediates: bool
+    ) -> Chain:
+        samples = tree_util.tree_map(
+            lambda arr: jnp.zeros_like(
+                arr, shape=(num_samples, ) + jnp.shape(arr)
+            ), position_proto
+        )
+        divergences = jnp.zeros(num_samples, dtype=bool)
+        chain = Chain(samples=samples, divergences=divergences, acceptance=0.)
+        if save_intermediates:
+            _qp_proto = QP(position_proto, position_proto)
+            _acc_rej_proto = AcceptedAndRejected(
+                _qp_proto, _qp_proto, True, True
+            )
+            trees = tree_util.tree_map(
+                lambda leaf: jnp.zeros_like(
+                    leaf, shape=(num_samples, ) + jnp.shape(leaf)
+                ), _acc_rej_proto
+            )
+            chain = chain._replace(trees=trees)
+
+        return chain
+
+    @staticmethod
+    def update_chain(
+        chain: Chain, idx: Union[jnp.ndarray, int], acc_rej: AcceptedAndRejected
+    ) -> Chain:
+        samples = tree_index_update(
+            chain.samples, idx, acc_rej.accepted_qp.position
+        )
+        divergences = chain.divergences.at[idx].set(acc_rej.diverging)
+        acceptance = (
+            chain.acceptance + (acc_rej.accepted - chain.acceptance) /
+            (idx + 1)
+        )
+        chain = chain._replace(
+            samples=samples, divergences=divergences, acceptance=acceptance
+        )
+        if chain.trees is not None:
+            trees = tree_index_update(chain.trees, idx, acc_rej)
+            chain = chain._replace(trees=trees)
+
+        return chain
diff --git a/src/re/kl.py b/src/re/kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a2eaf9c4082f6521fdfb6377f4a88b7411bdd4
--- /dev/null
+++ b/src/re/kl.py
@@ -0,0 +1,646 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union
+
+import jax
+from jax import lax
+from jax import random
+from jax.tree_util import Partial, register_pytree_node_class
+
+from . import conjugate_gradient
+from .forest_util import assert_arithmetics, map_forest, map_forest_mean, unstack
+from .likelihood import Likelihood, StandardHamiltonian
+from .sugar import random_like
+
+P = TypeVar("P")
+
+
+def sample_likelihood(likelihood: Likelihood, primals, key):
+    white_sample = random_like(key, likelihood.left_sqrt_metric_tangents_shape)
+    return likelihood.left_sqrt_metric(primals, white_sample)
+
+
+def cond_raise(condition, exception):
+    from jax.experimental.host_callback import call
+
+    def maybe_raise(condition):
+        if condition:
+            raise exception
+
+    call(maybe_raise, condition, result_shape=None)
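+
+
+# `cond_raise` escapes the JAX trace via a host callback and hence also works
+# on abstract tracer values. A minimal sketch (`energy` is an assumption):
+#
+#   cond_raise(jnp.isnan(energy), ValueError("energy is NaN"))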
+
+
+def _sample_standard_hamiltonian(
+    hamiltonian: StandardHamiltonian,
+    primals,
+    key,
+    from_inverse: bool,
+    cg: Callable = conjugate_gradient.static_cg,
+    cg_name: Optional[str] = None,
+    cg_kwargs: Optional[dict] = None,
+):
+    if not isinstance(hamiltonian, StandardHamiltonian):
+        te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
+        raise TypeError(te)
+    cg_kwargs = cg_kwargs if cg_kwargs is not None else {}
+
+    subkey_nll, subkey_prr = random.split(key, 2)
+    nll_smpl = sample_likelihood(
+        hamiltonian.likelihood, primals, key=subkey_nll
+    )
+    prr_inv_metric_smpl = random_like(key=subkey_prr, primals=primals)
+    # One may transform any metric sample to a sample of the inverse
+    # metric by simply applying the inverse metric to it
+    prr_smpl = prr_inv_metric_smpl
+    # Note, we can sample antithetically by swapping the global sign of
+    # the metric sample below (which corresponds to mirroring the final
+    # sample) and additionally by swapping the relative sign between
+    # the prior and the likelihood sample. The first technique is
+    # computationally cheap and empirically known to improve stability.
+    # The latter technique requires an additional inversion and its
+    # impact on stability is still unknown.
+    # TODO: investigate the impact of sampling the prior and likelihood
+    # antithetically.
+    met_smpl = nll_smpl + prr_smpl
+    if from_inverse:
+        inv_metric_at_p = partial(
+            cg, Partial(hamiltonian.metric, primals), **{
+                "name": cg_name,
+                **cg_kwargs
+            }
+        )
+        signal_smpl, info = inv_metric_at_p(met_smpl, x0=prr_inv_metric_smpl)
+        cond_raise(
+            (info is not None) & (info < 0),
+            ValueError("conjugate gradient failed")
+        )
+        return signal_smpl, met_smpl
+    else:
+        return None, met_smpl
+
+
+def sample_standard_hamiltonian(
+    hamiltonian: StandardHamiltonian, primals, *args, **kwargs
+):
+    r"""Draws a sample of which the covariance is the metric or the inverse
+    metric of the Hamiltonian.
+
+    To sample from the inverse metric, we need to be able to draw samples
+    which have the metric as covariance structure and we need to be able to
+    apply the inverse metric. The first part is trivial since we can use
+    the left square root of the metric :math:`L` associated with every
+    likelihood:
+
+    .. math::
+
+        \tilde{d} \leftarrow \mathcal{G}(0,\mathbb{1}) \\
+        t = L \tilde{d}
+
+    with :math:`t` now having a covariance structure of
+
+    .. math::
+        <t t^\dagger> = L <\tilde{d} \tilde{d}^\dagger> L^\dagger = M .
+
+    We now need to apply the inverse metric in order to transform the
+    sample to an inverse sample. We can do so using the conjugate gradient
+    algorithm which yields the solution to :math:`M s = t`, i.e. applies the
+    inverse of :math:`M` to :math:`t`:
+
+    .. math::
+
+        M &s =  t \\
+        &s = M^{-1} t = cg(M, t) .
+
+    Parameters
+    ----------
+    hamiltonian:
+        Hamiltonian with standard prior from which to draw samples.
+    primals : tree-like structure
+        Position at which to draw samples.
+    key : tuple, list or jnp.ndarray of uint32 of length two
+        Random key with which to generate random variables in data domain.
+    cg : callable, optional
+        Implementation of the conjugate gradient algorithm, used to
+        apply the inverse of the metric.
+    cg_kwargs : dict, optional
+        Additional keyword arguments passed on to `cg`.
+
+    Returns
+    -------
+    sample : tree-like structure
+        Sample of which the covariance is the inverse metric.
+
+    See also
+    --------
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
+    """
+    assert_arithmetics(primals)
+    inv_met_smpl, _ = _sample_standard_hamiltonian(
+        hamiltonian, primals, *args, from_inverse=True, **kwargs
+    )
+    return inv_met_smpl
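+
+
+# A minimal usage sketch (hypothetical setup; `ham` is assumed to be a
+# `StandardHamiltonian` and `pos` a matching latent position):
+#
+#   smpl = sample_standard_hamiltonian(ham, pos, key=random.PRNGKey(42))
+#   # `smpl` has the inverse metric of `ham` at `pos` as its covariance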
+
+
+def geometrically_sample_standard_hamiltonian(
+    hamiltonian: StandardHamiltonian,
+    primals,
+    key,
+    mirror_linear_sample: bool,
+    linear_sampling_cg: Callable = conjugate_gradient.static_cg,
+    linear_sampling_name: Optional[str] = None,
+    linear_sampling_kwargs: Optional[dict] = None,
+    non_linear_sampling_method: str = "NewtonCG",
+    non_linear_sampling_name: Optional[str] = None,
+    non_linear_sampling_kwargs: Optional[dict] = None,
+):
+    r"""Draws a sample which follows a standard normal distribution in the
+    canonical coordinate system of the Riemannian manifold associated with the
+    metric of the other distribution. The coordinate transformation is
+    approximated by expanding around a given point `primals`.
+
+    Parameters
+    ----------
+    hamiltonian:
+        Hamiltonian with standard prior from which to draw samples.
+    primals : tree-like structure
+        Position at which to draw samples.
+    key : tuple, list or jnp.ndarray of uint32 of length two
+        Random key with which to generate random variables in data domain.
+    linear_sampling_cg : callable
+        Implementation of the conjugate gradient algorithm, used to
+        apply the inverse of the metric.
+    linear_sampling_kwargs : dict
+        Additional keyword arguments passed on to `linear_sampling_cg`.
+    non_linear_sampling_kwargs : dict
+        Additional keyword arguments passed on to the minimizer of the
+        non-linear potential.
+
+    Returns
+    -------
+    sample : tree-like structure
+        Sample of which the covariance is the inverse metric.
+
+    See also
+    --------
+    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
+    `<https://doi.org/10.3390/e23070853>`_
+    """
+    if not isinstance(hamiltonian, StandardHamiltonian):
+        te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
+        raise TypeError(te)
+    assert_arithmetics(primals)
+    from .energy_operators import Gaussian
+    from .optimize import minimize
+
+    inv_met_smpl, met_smpl = _sample_standard_hamiltonian(
+        hamiltonian,
+        primals,
+        key=key,
+        from_inverse=True,
+        cg=linear_sampling_cg,
+        cg_name=linear_sampling_name,
+        cg_kwargs=linear_sampling_kwargs
+    )
+
+    if isinstance(non_linear_sampling_kwargs, dict):
+        nls_kwargs = non_linear_sampling_kwargs
+    elif non_linear_sampling_kwargs is None:
+        nls_kwargs = {}
+    else:
+        te = (
+            "`non_linear_sampling_kwargs` of invalid type"
+            f" {type(non_linear_sampling_kwargs)!r}"
+        )
+        raise TypeError(te)
+    nls_kwargs = {"name": non_linear_sampling_name, **nls_kwargs}
+    if "hessp" in nls_kwargs:
+        ve = "setting the hessian for an unknown function is invalid"
+        raise ValueError(ve)
+    # Abort early if non-linear sampling is effectively disabled
+    if nls_kwargs.get("maxiter") == 0:
+        if mirror_linear_sample:
+            return (inv_met_smpl, -inv_met_smpl)
+        return (inv_met_smpl, )
+
+    lh_trafo_at_p = hamiltonian.likelihood.transformation(primals)
+
+    def draw_non_linear_sample(lh, met_smpl, inv_met_smpl):
+        x0 = primals + inv_met_smpl
+
+        def g(x):
+            return x - primals + lh.left_sqrt_metric(
+                primals,
+                lh.transformation(x) - lh_trafo_at_p
+            )
+
+        r2_half = Gaussian(met_smpl) @ g  # (g - met_smpl)**2 / 2
+
+        options = nls_kwargs.copy()
+        options["hessp"] = r2_half.metric
+
+        opt_state = minimize(
+            r2_half, x0=x0, method=non_linear_sampling_method, options=options
+        )
+
+        return opt_state.x, opt_state.status
+
+    smpl1, smpl1_status = draw_non_linear_sample(
+        hamiltonian.likelihood, met_smpl, inv_met_smpl
+    )
+    cond_raise(
+        (smpl1_status is not None) & (smpl1_status < 0),
+        ValueError("S: failed to invert map")
+    )
+    if not mirror_linear_sample:
+        return (smpl1 - primals, )
+    smpl2, smpl2_status = draw_non_linear_sample(
+        hamiltonian.likelihood, -met_smpl, -inv_met_smpl
+    )
+    cond_raise(
+        (smpl2_status is not None) & (smpl2_status < 0),
+        ValueError("S: failed to invert map")
+    )
+    return (smpl1 - primals, smpl2 - primals)
+
+
+@register_pytree_node_class
+class SampleIter():
+    """Storage class for samples with some convenience methods for applying
+    operators of them
+
+    This class is used to store samples for the Variational Inference schemes
+    MGVI and geoVI where samples are defined relative to some expansion point
+    (a.k.a. latent mean or offset).
+
+    See also
+    --------
+    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
+    `<https://doi.org/10.3390/e23070853>`_
+
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
+    """
+    def __init__(
+        self,
+        *,
+        mean: P = None,
+        samples: Sequence[P],
+        linearly_mirror_samples: bool = False,
+    ):
+        self._samples = tuple(samples)
+        self._mean = mean
+
+        self._n_samples = len(self._samples)
+        if linearly_mirror_samples:
+            self._n_samples *= 2
+        self._linearly_mirror_samples = linearly_mirror_samples
+        # TODO/IDEA: Implement a transposed SampleIter object (SampleStack)
+        # akin to `vmap_forest_mean`
+
+    def __iter__(self):
+        for s in self._samples:
+            yield self._mean + s if self._mean is not None else s
+            if self._linearly_mirror_samples:
+                yield self._mean - s if self._mean is not None else -s
+
+    def __len__(self):
+        return self._n_samples
+
+    @property
+    def n_samples(self):
+        """Total number of samples, equivalent to the length of the object"""
+        return len(self)
+
+    def at(self, mean):
+        """Updates the offset (usually the latent mean) of all samples"""
+        return SampleIter(
+            mean=mean,
+            samples=self._samples,
+            linearly_mirror_samples=self._linearly_mirror_samples
+        )
+
+    @property
+    def first(self):
+        """Convenience method to easily retrieve a sample (the first one)"""
+        if self._mean is not None:
+            return self._mean + self._samples[0]
+        return self._samples[0]
+
+    def apply(self, call: Callable, *args, **kwargs):
+        """Applies an operator over all samples, yielding a list of outputs
+
+        Internally, the call is `vmap`-ed over the samples for additional
+        efficiency.
+        """
+        if set(kwargs.keys()) | {"in_axes"} != {"in_axes"}:
+            raise ValueError(f"invalid keyword arguments {kwargs}")
+
+        # TODO: vmap is significantly slower than looping over the samples
+        # for an extremely high dimensional problem.
+        in_axes = kwargs.get("in_axes", (0, ))
+        return map_forest(call, in_axes=in_axes)(tuple(self), *args)
+
+    def mean(self, call: Callable, *args, **kwargs):
+        """Applies an operator over all samples and averages the results
+
+        Internally, the call is `vmap`-ed over the samples for additional
+        efficiency.
+        """
+        if set(kwargs.keys()) | {"in_axes"} != {"in_axes"}:
+            raise ValueError(f"invalid keyword arguments {kwargs}")
+
+        # TODO: vmap is significantly slower than looping over the samples
+        # for an extremely high dimensional problem.
+        in_axes = kwargs.get("in_axes", (0, ))
+        return map_forest_mean(call, in_axes=in_axes)(tuple(self), *args)
+
+    def tree_flatten(self):
+        return ((self._mean, self._samples), (self._linearly_mirror_samples, ))
+
+    @classmethod
+    def tree_unflatten(cls, aux, children):
+        if len(aux) != 1 or len(children) != 2:
+            raise ValueError()
+        return cls(
+            mean=children[0],
+            samples=children[1],
+            linearly_mirror_samples=aux[0]
+        )
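+
+
+# A hedged sketch of how `SampleIter` is meant to be used (`mean`, `s1` and
+# `s2` are assumed to be compatible pytrees):
+#
+#   samples = SampleIter(
+#       mean=mean, samples=(s1, s2), linearly_mirror_samples=True
+#   )
+#   # iterates as mean + s1, mean - s1, mean + s2, mean - s2; len is 4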
+
+
+def MetricKL(
+    hamiltonian: StandardHamiltonian,
+    primals,
+    n_samples: int,
+    key,
+    mirror_samples: bool = True,
+    sample_mapping: Union[str, Callable] = 'lax',
+    linear_sampling_cg: Callable = conjugate_gradient.static_cg,
+    linear_sampling_name: Optional[str] = None,
+    linear_sampling_kwargs: Optional[dict] = None,
+) -> SampleIter:
+    """Provides the sampled Kullback-Leibler divergence between a distribution
+    and a Metric Gaussian.
+
+    A Metric Gaussian is used to approximate another probability distribution.
+    It is a Gaussian distribution that uses the Fisher information metric of
+    the other distribution at the location of its mean to approximate the
+    variance. In order to infer the mean, a stochastic estimate of the
+    Kullback-Leibler divergence is minimized. This estimate is obtained by
+    sampling the Metric Gaussian at the current mean. During minimization these
+    samples are kept constant and only the mean is updated. Due to the
+    typically nonlinear structure of the true distribution these samples have
+    to be updated eventually by re-instantiating the Metric Gaussian again. For
+    the true probability distribution the standard parametrization is assumed.
+
+    Parameters
+    ----------
+    hamiltonian : :class:`nifty8.re.likelihood.StandardHamiltonian`
+        Hamiltonian of the approximated probability distribution.
+    primals : :class:`nifty8.re.field.Field`
+        Expansion point of the coordinate transformation.
+    n_samples : integer
+        Number of samples used to stochastically estimate the KL.
+    key : DeviceArray
+        A PRNG-key.
+    mirror_samples : boolean
+        Whether the mirrored version of the drawn samples are also used.
+        If true, the number of used samples doubles.
+        Mirroring samples stabilizes the KL estimate as extreme
+        sample variation is counterbalanced.
+        Default is True.
+    sample_mapping : string, callable
+        Can be either a string-key to a mapping function or a mapping function
+        itself. The function is used to map the drawing of samples. Possible
+        string-keys are:
+
+        keys                -       functions
+        -------------------------------------
+        'pmap' or 'p'       -       jax.pmap
+        'lax.map' or 'lax'  -       jax.lax.map
+
+        In case sample_mapping is passed as a function, it should produce a
+        mapped function f_mapped of a general function f as: `f_mapped =
+        sample_mapping(f)`
+    linear_sampling_cg : callable
+        Implementation of the conjugate gradient algorithm, used to
+        apply the inverse of the metric.
+    linear_sampling_name : string, optional
+        'name'-keyword-argument passed to `linear_sampling_cg`.
+    linear_sampling_kwargs : dict, optional
+        Additional keyword arguments passed on to `linear_sampling_cg`.
+
+    See also
+    --------
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
+    """
+    if not isinstance(hamiltonian, StandardHamiltonian):
+        te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
+        raise TypeError(te)
+    assert_arithmetics(primals)
+
+    draw = partial(
+        sample_standard_hamiltonian,
+        hamiltonian=hamiltonian,
+        primals=primals,
+        cg=linear_sampling_cg,
+        cg_name=linear_sampling_name,
+        cg_kwargs=linear_sampling_kwargs
+    )
+    subkeys = random.split(key, n_samples)
+    if isinstance(sample_mapping, str):
+        if sample_mapping == 'pmap' or sample_mapping == 'p':
+            sample_mapping = jax.pmap
+        elif sample_mapping == 'lax.map' or sample_mapping == 'lax':
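+            # `partial(partial, lax.map)` maps `f` to `partial(lax.map, f)`,
+            # i.e. applying the result to the subkeys is `lax.map(f, subkeys)`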
+            sample_mapping = partial(partial, lax.map)
+        else:
+            ve = (
+                f"{sample_mapping} is not an accepted key to a mapping function"
+                "; please pass function directly"
+            )
+            raise ValueError(ve)
+
+    elif not callable(sample_mapping):
+        te = (
+            f"invalid `sample_mapping` of type {type(sample_mapping)!r}"
+            "; expected string or callable"
+        )
+        raise TypeError(te)
+
+    samples_stack = sample_mapping(lambda k: draw(key=k))(subkeys)
+
+    return SampleIter(
+        mean=primals,
+        samples=unstack(samples_stack),
+        linearly_mirror_samples=mirror_samples
+    )
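+
+
+# A hedged usage sketch (`ham` and `pos` as in the sketch above):
+#
+#   samples = MetricKL(ham, pos, n_samples=4, key=random.PRNGKey(33))
+#   kl = samples.mean(ham)  # stochastic KL estimate up to additive constants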
+
+
+def GeoMetricKL(
+    hamiltonian: StandardHamiltonian,
+    primals,
+    n_samples: int,
+    key,
+    mirror_samples: bool = True,
+    linear_sampling_cg: Callable = conjugate_gradient.static_cg,
+    linear_sampling_name: Optional[str] = None,
+    linear_sampling_kwargs: Optional[dict] = None,
+    non_linear_sampling_method: str = "NewtonCG",
+    non_linear_sampling_name: Optional[str] = None,
+    non_linear_sampling_kwargs: Optional[dict] = None,
+) -> SampleIter:
+    """Provides the sampled Kullback-Leibler used in geometric Variational
+    Inference (geoVI).
+
+    In geoVI a probability distribution is approximated with a standard normal
+    distribution in the canonical coordinate system of the Riemannian manifold
+    associated with the metric of the other distribution. The coordinate
+    transformation is approximated by expanding around a point. In order to
+    infer the expansion point, a stochastic estimate of the Kullback-Leibler
+    divergence is minimized. This estimate is obtained by sampling from the
+    approximation using the current expansion point. During minimization these
+    samples are kept constant and only the expansion point is updated. Due to
+    the typically nonlinear structure of the true distribution these samples
+    have to be updated eventually by re-instantiating the geometric Gaussian
+    again. For the true probability distribution the standard parametrization
+    is assumed.
+
+    See also
+    --------
+    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
+    `<https://doi.org/10.3390/e23070853>`_
+    """
+    if not isinstance(hamiltonian, StandardHamiltonian):
+        te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
+        raise TypeError(te)
+    assert_arithmetics(primals)
+
+    draw = partial(
+        geometrically_sample_standard_hamiltonian,
+        hamiltonian=hamiltonian,
+        primals=primals,
+        mirror_linear_sample=mirror_samples,
+        linear_sampling_cg=linear_sampling_cg,
+        linear_sampling_name=linear_sampling_name,
+        linear_sampling_kwargs=linear_sampling_kwargs,
+        non_linear_sampling_method=non_linear_sampling_method,
+        non_linear_sampling_name=non_linear_sampling_name,
+        non_linear_sampling_kwargs=non_linear_sampling_kwargs
+    )
+    subkeys = random.split(key, n_samples)
+    # TODO: Make `geometrically_sample_standard_hamiltonian` jit-able
+    # samples_stack = lax.map(lambda k: draw(key=k), subkeys)
+    # Unpack tuple of samples
+    # samples_stack = tree_map(
+    #     lambda a: a.reshape((-1, ) + a.shape[2:]), samples_stack
+    # )
+    # samples = unstack(samples_stack)
+    samples = tuple(s for ss in map(lambda k: draw(key=k), subkeys) for s in ss)
+
+    return SampleIter(
+        mean=primals, samples=samples, linearly_mirror_samples=False
+    )
+
+
+def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs):
+    """Thin wrapper around `value_and_grad` and the provided sample mapping
+    function, e.g. `vmap` to apply a cost function to a mean and a list of
+    residual samples.
+
+    Parameters
+    ----------
+    ham : :class:`nifty8.re.likelihood.StandardHamiltonian`
+        Hamiltonian of the approximated probability distribution,
+        of which the mean value and the mean gradient are to be computed.
+    sample_mapping : string, callable
+        Can be either a string-key to a mapping function or a mapping function
+        itself. The function is used to map the drawing of samples. Possible
+        string-keys are:
+
+        keys                -       functions
+        -------------------------------------
+        'vmap' or 'v'       -       jax.vmap
+        'pmap' or 'p'       -       jax.pmap
+        'lax.map' or 'lax'  -       jax.lax.map
+
+        In case sample_mapping is passed as a function, it should produce a
+        mapped function f_mapped of a general function f as: `f_mapped =
+        sample_mapping(f)`
+    """
+    from jax import value_and_grad
+    vg = value_and_grad(ham, *args, **kwargs)
+
+    def mean_vg(
+        primals: P,
+        primals_samples: Union[None, Sequence[P], SampleIter] = None,
+        **primals_kw
+    ) -> Tuple[Any, P]:
+        ham_vg = partial(vg, **primals_kw)
+        if primals_samples is None:
+            return ham_vg(primals)
+
+        if not isinstance(primals_samples, SampleIter):
+            primals_samples = SampleIter(samples=primals_samples)
+        return map_forest_mean(ham_vg, mapping=sample_mapping, in_axes=(0, ))(
+            tuple(primals_samples.at(primals))
+        )
+
+    return mean_vg
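+
+
+# Sketch (assuming `ham` is a scalar energy, `pos` a latent position and
+# `samples` a `SampleIter`):
+#
+#   vg = mean_value_and_grad(ham)
+#   energy, gradient = vg(pos, primals_samples=samples)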
+
+
+def mean_hessp(ham: Callable, *args, **kwargs):
+    """Thin wrapper around `jvp`, `grad` and `vmap` to apply a binary method to
+    a primal mean, a tangent and a list of residual primal samples.
+    """
+    from jax import jvp, grad
+    jac = grad(ham, *args, **kwargs)
+
+    def mean_hp(
+        primals: P,
+        tangents: Any,
+        primals_samples: Union[None, Sequence[P], SampleIter] = None,
+        **primals_kw
+    ) -> P:
+        if primals_samples is None:
+            _, hp = jvp(partial(jac, **primals_kw), (primals, ), (tangents, ))
+            return hp
+
+        if not isinstance(primals_samples, SampleIter):
+            primals_samples = SampleIter(samples=primals_samples)
+        return map_forest_mean(
+            partial(mean_hp, primals_samples=None, **primals_kw),
+            in_axes=(0, None)
+        )(tuple(primals_samples.at(primals)), tangents)
+
+    return mean_hp
+
+
+def mean_metric(metric: Callable):
+    """Thin wrapper around `vmap` to apply a binary method to a primal mean, a
+    tangent and a list of residual primal samples.
+    """
+    def mean_met(
+        primals: P,
+        tangents: Any,
+        primals_samples: Union[None, Sequence[P], SampleIter] = None,
+        **primals_kw
+    ) -> P:
+        if primals_samples is None:
+            return metric(primals, tangents, **primals_kw)
+
+        if not isinstance(primals_samples, SampleIter):
+            primals_samples = SampleIter(samples=primals_samples)
+        return map_forest_mean(
+            partial(metric, **primals_kw), in_axes=(0, None)
+        )(tuple(primals_samples.at(primals)), tangents)
+
+    return mean_met
diff --git a/src/re/lanczos.py b/src/re/lanczos.py
new file mode 100644
index 0000000000000000000000000000000000000000..f684246fe3a529166a4036415325394e6762898b
--- /dev/null
+++ b/src/re/lanczos.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from typing import Callable, Optional, Union
+
+import jax
+from jax import numpy as jnp
+from jax import random
+
+from .forest_util import ShapeWithDtype
+
+
+def lanczos_tridiag(
+    mat: Callable, shape_dtype_struct: ShapeWithDtype, order: int,
+    key: jnp.ndarray
+):
+    """Compute the Lanczos decomposition into a tri-diagonal matrix and its
+    corresponding orthonormal projection matrix.
+    """
+    tridiag = jnp.zeros((order, order), dtype=shape_dtype_struct.dtype)
+    vecs = jnp.zeros(
+        (order, ) + shape_dtype_struct.shape, dtype=shape_dtype_struct.dtype
+    )
+
+    v = random.normal(key, shape=shape_dtype_struct.shape)
+    v = v / jnp.linalg.norm(v)
+    vecs = vecs.at[0].set(v)
+
+    # Zeroth iteration
+    w = mat(v)
+    if w.shape != shape_dtype_struct.shape:
+        ve = f"shape of `mat(v)` {w.shape!r} incompatible with {shape_dtype_struct}"
+        raise ValueError(ve)
+    alpha = jnp.dot(w, v)
+    tridiag = tridiag.at[(0, 0)].set(alpha)
+    w -= alpha * v
+    beta = jnp.linalg.norm(w)
+
+    tridiag = tridiag.at[(0, 1)].set(beta)
+    tridiag = tridiag.at[(1, 0)].set(beta)
+    vecs = vecs.at[1].set(w / beta)
+
+    def reortho_step(j, state):
+        vecs, w = state
+
+        tau = vecs[j, :].reshape(shape_dtype_struct.shape)
+        coeff = jnp.dot(w, tau)
+        w -= coeff * tau
+        return vecs, w
+
+    def lanczos_step(i, state):
+        tridiag, vecs, beta = state
+
+        v = vecs[i, :].reshape(shape_dtype_struct.shape)
+        v_old = vecs[i - 1, :].reshape(shape_dtype_struct.shape)
+
+        w = mat(v) - beta * v_old
+        alpha = jnp.dot(w, v)
+        tridiag = tridiag.at[(i, i)].set(alpha)
+        w -= alpha * v
+
+        # Full reorthogonalization
+        vecs, w = jax.lax.fori_loop(0, i, reortho_step, (vecs, w))
+
+        # TODO: Raise if lanczos vectors are independent i.e. `beta` small?
+        beta = jnp.linalg.norm(w)
+
+        tridiag = tridiag.at[(i, i + 1)].set(beta)
+        tridiag = tridiag.at[(i + 1, i)].set(beta)
+        vecs = vecs.at[i + 1].set(w / beta)
+
+        return tridiag, vecs, beta
+
+    tridiag, vecs, beta = jax.lax.fori_loop(
+        1, order - 1, lanczos_step, (tridiag, vecs, beta)
+    )
+
+    # Final tridiag value and reorthogonalization
+    v = vecs[order - 1, :].reshape(shape_dtype_struct.shape)
+    v_old = vecs[order - 2, :].reshape(shape_dtype_struct.shape)
+    w = mat(v) - beta * v_old
+    alpha = jnp.dot(w, v)
+    tridiag = tridiag.at[(order - 1, order - 1)].set(alpha)
+    w -= alpha * v
+    vecs, w = jax.lax.fori_loop(0, order - 1, reortho_step, (vecs, w))
+
+    return (tridiag, vecs)
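+
+
+# Sanity-check sketch for a symmetric matrix `m` (a hypothetical example; the
+# Lanczos relation implies `tridiag` ≈ `vecs @ m @ vecs.T`):
+#
+#   tridiag, vecs = lanczos_tridiag(
+#       m.__matmul__, ShapeWithDtype((4, ), m.dtype), order=4,
+#       key=random.PRNGKey(0)
+#   )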
+
+
+def stochastic_logdet_from_lanczos(
+    tridiag_stack: jnp.ndarray, matrix_shape0: int, func: Callable = jnp.log
+):
+    """Computes a stochastic estimate of the log-determinate of a matrix using
+    its Lanczos decomposition.
+
+    Implemented via the stoachstic Lanczos quadrature.
+    """
+    eig_vals, eig_vecs = jnp.linalg.eigh(tridiag_stack)
+    # TODO: Mask Eigenvalues <= 0?
+
+    num_random_probes = tridiag_stack.shape[0]
+
+    eig_vecs_first_component = eig_vecs[..., 0, :]
+    func_of_eig_vals = func(eig_vals)
+
+    dot_products = jnp.sum(eig_vecs_first_component**2 * func_of_eig_vals)
+    return matrix_shape0 / float(num_random_probes) * dot_products
+
+
+def stochastic_lq_logdet(
+    mat: Union[jnp.ndarray, Callable],
+    order: int,
+    n_samples: int,
+    key: Union[int, jnp.ndarray],
+    *,
+    shape0: Optional[int] = None,
+    dtype=None
+):
+    """Computes a stochastic estimate of the log-determinate of a matrix using
+    the stoachstic Lanczos quadrature algorithm.
+    """
+    shape0 = shape0 if shape0 is not None else mat.shape[0]
+    mat = mat if callable(mat) else mat.__matmul__
+    if not isinstance(key, jnp.ndarray):
+        key = random.PRNGKey(key)
+    keys = random.split(key, n_samples)
+
+    lanczos = partial(lanczos_tridiag, mat, ShapeWithDtype(shape0, dtype))
+    tridiags, _ = jax.vmap(lanczos, in_axes=(None, 0),
+                           out_axes=(0, 0))(order, keys)
+    return stochastic_logdet_from_lanczos(tridiags, shape0)
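+
+
+# A hedged usage sketch on a small SPD matrix (the estimate is stochastic and
+# only approximate):
+#
+#   m = jnp.diag(jnp.linspace(1., 10., 10))
+#   ld = stochastic_lq_logdet(m, order=8, n_samples=16, key=42)
+#   # compare against jnp.sum(jnp.log(jnp.linspace(1., 10., 10)))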
diff --git a/src/re/likelihood.py b/src/re/likelihood.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a9621e33486dde2b2060126f3f5f359da0b25c
--- /dev/null
+++ b/src/re/likelihood.py
@@ -0,0 +1,390 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from typing import Any, Callable, Optional, TypeVar, Union
+
+from jax import numpy as jnp
+from jax import linear_transpose, linearize, vjp
+from jax.tree_util import Partial, tree_leaves
+
+from .forest_util import ShapeWithDtype, split
+from .sugar import is1d, isiterable, sum_of_squares, doc_from
+
+Q = TypeVar("Q")
+
+
+class Likelihood():
+    """Storage class for keeping track of the energy, the associated
+    left-square-root of the metric and the metric.
+    """
+    def __init__(
+        self,
+        energy: Callable[..., Union[jnp.ndarray, float]],
+        transformation: Optional[Callable[[Q], Any]] = None,
+        left_sqrt_metric: Optional[Callable[[Q, Q], Any]] = None,
+        metric: Optional[Callable[[Q, Q], Any]] = None,
+        lsm_tangents_shape=None
+    ):
+        """Instantiates a new likelihood.
+
+        Parameters
+        ----------
+        energy : callable
+            Function evaluating the negative log-likelihood.
+        transformation : callable, optional
+            Function evaluating the geometric transformation of the likelihood.
+        left_sqrt_metric : callable, optional
+            Function applying the left-square-root of the metric.
+        metric : callable, optional
+            Function applying the metric.
+        lsm_tangents_shape : tree-like structure of ShapeWithDtype, optional
+            Structure of the data space.
+        """
+        self._hamiltonian = energy
+        self._transformation = transformation
+        self._left_sqrt_metric = left_sqrt_metric
+        self._metric = metric
+
+        if lsm_tangents_shape is not None:
+            leaves = tree_leaves(lsm_tangents_shape)
+            if not all(
+                hasattr(e, "shape") and hasattr(e, "dtype") for e in leaves
+            ):
+                if is1d(lsm_tangents_shape) or not isiterable(
+                    lsm_tangents_shape
+                ):
+                    lsm_tangents_shape = ShapeWithDtype(lsm_tangents_shape)
+                else:
+                    te = "`lsm_tangents_shape` of invalid type"
+                    raise TypeError(te)
+        self._lsm_tan_shp = lsm_tangents_shape
+
+    def __call__(self, primals, **primals_kw):
+        """Convenience method to access the `energy` method of this instance.
+        """
+        return self.energy(primals, **primals_kw)
+
+    def energy(self, primals, **primals_kw):
+        """Applies the energy to `primals`.
+
+        Parameters
+        ----------
+        primals : tree-like structure
+            Position at which to evaluate the energy.
+        **primals_kw : Any
+           Additional arguments passed on to the energy.
+
+        Returns
+        -------
+        energy : float
+            Energy at the position `primals`.
+        """
+        return self._hamiltonian(primals, **primals_kw)
+
+    def metric(self, primals, tangents, **primals_kw):
+        """Applies the metric at `primals` to `tangents`.
+
+        Parameters
+        ----------
+        primals : tree-like structure
+            Position at which to evaluate the metric.
+        tangents : tree-like structure
+            Instance to which to apply the metric.
+        **primals_kw : Any
+           Additional arguments passed on to the metric.
+
+        Returns
+        -------
+        naturally_curved : tree-like structure
+            Tree-like structure of the same type as primals to which the
+            metric has been applied.
+        """
+        if self._metric is None:
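+            # With no explicit metric, assemble M = L L^T from the
+            # left-square-root L: transpose it to obtain R = L^T, then
+            # apply L(R(tangents))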
+            lsm_at_p = Partial(self.left_sqrt_metric, primals, **primals_kw)
+            rsm_at_p = linear_transpose(
+                lsm_at_p, self.left_sqrt_metric_tangents_shape
+            )
+            res = lsm_at_p(*rsm_at_p(tangents))
+            return res
+        return self._metric(primals, tangents, **primals_kw)
+
+    def left_sqrt_metric(self, primals, tangents, **primals_kw):
+        """Applies the left-square-root of the metric at `primals` to
+        `tangents`.
+
+        Parameters
+        ----------
+        primals : tree-like structure
+            Position at which to evaluate the metric.
+        tangents : tree-like structure
+            Instance to which to apply the metric.
+        **primals_kw : Any
+           Additional arguments passed on to the LSM.
+
+        Returns
+        -------
+        metric_sample : tree-like structure
+            Tree-like structure of the same type as primals to which the
+            left-square-root of the metric has been applied.
+        """
+        if self._left_sqrt_metric is None:
+            _, bwd = vjp(Partial(self.transformation, **primals_kw), primals)
+            res = bwd(tangents)
+            return res[0]
+        return self._left_sqrt_metric(primals, tangents, **primals_kw)
+
+    def transformation(self, primals, **primals_kw):
+        """Applies the coordinate transformation that maps into a coordinate
+        system in which the metric of the likelihood is the Euclidean metric.
+
+        Parameters
+        ----------
+        primals : tree-like structure
+            Position at which to transform.
+        **primals_kw : Any
+           Additional arguments passed on to the transformation.
+
+        Returns
+        -------
+        transformed_sample : tree-like structure
+            Structure of the same type as primals to which the geometric
+            transformation has been applied.
+        """
+        if self._transformation is None:
+            nie = "`transformation` is not implemented"
+            raise NotImplementedError(nie)
+        return self._transformation(primals, **primals_kw)
+
+    @property
+    def left_sqrt_metric_tangents_shape(self):
+        """Retrieves the shape of the tangent domain of the
+        left-square-root of the metric.
+        """
+        return self._lsm_tan_shp
+
+    @property
+    def lsm_tangents_shape(self):
+        """Alias for `left_sqrt_metric_tangents_shape`."""
+        return self.left_sqrt_metric_tangents_shape
+
+    def new(
+        self, energy: Callable, transformation: Optional[Callable],
+        left_sqrt_metric: Optional[Callable], metric: Optional[Callable]
+    ):
+        """Instantiates a new likelihood with the same `lsm_tangents_shape`.
+
+        Parameters
+        ----------
+        energy : callable
+            Function evaluating the negative log-likelihood.
+        transformation : callable, optional
+            Function evaluating the geometric transformation of the
+            log-likelihood.
+        left_sqrt_metric : callable, optional
+            Function applying the left-square-root of the metric.
+        metric : callable, optional
+            Function applying the metric.
+        """
+        return Likelihood(
+            energy,
+            transformation=transformation,
+            left_sqrt_metric=left_sqrt_metric,
+            metric=metric,
+            lsm_tangents_shape=self._lsm_tan_shp
+        )
+
+    def jit(self, **kwargs):
+        """Returns a new likelihood with jit-compiled energy, left-square-root
+        of metric and metric.
+        """
+        from jax import jit
+
+        if self._transformation is not None:
+            j_trafo = jit(self.transformation, **kwargs)
+            j_lsm = jit(self.left_sqrt_metric, **kwargs)
+            j_m = jit(self.metric, **kwargs)
+        elif self._left_sqrt_metric is not None:
+            j_trafo = None
+            j_lsm = jit(self.left_sqrt_metric, **kwargs)
+            j_m = jit(self.metric, **kwargs)
+        elif self._metric is not None:
+            j_trafo, j_lsm = None, None
+            j_m = jit(self.metric, **kwargs)
+        else:
+            j_trafo, j_lsm, j_m = None, None, None
+
+        return self.new(
+            jit(self._hamiltonian, **kwargs),
+            transformation=j_trafo,
+            left_sqrt_metric=j_lsm,
+            metric=j_m
+        )
+
+    def __matmul__(self, f: Callable):
+        return self.matmul(f, left_argnames=(), right_argnames=None)
+
+    def matmul(self, f: Callable, left_argnames=(), right_argnames=None):
+        """Amend the function `f` to the right of the likelihood.
+
+        Parameters
+        ----------
+        f : Callable
+            Function which to compose with the likelihood.
+        left_argnames : tuple or None
+            Keys of the keyword arguments of the joined likelihood which
+            to pass to the original likelihood. Passing `None` indicates
+            the intent to absorb everything not explicitly absorbed by
+            the other call.
+        right_argnames : tuple or None
+            Keys of the keyword arguments of the joined likelihood which
+            to pass to the amended function. Passing `None` indicates
+            the intent to absorb everything not explicitly absorbed by
+            the other call.
+
+        Returns
+        -------
+        lh : Likelihood
+        """
+        if (left_argnames is None) == (right_argnames is None):
+            ve = (
+                "exactly one of `left_argnames` and `right_argnames`"
+                " must be `None`"
+            )
+            raise ValueError(ve)
+
+        def split_kwargs(**kwargs):
+            if left_argnames is None:  # right_argnames must be not None
+                right_kw, left_kw = split(kwargs, right_argnames)
+            else:  # right_argnames must be None
+                left_kw, right_kw = split(kwargs, left_argnames)
+            return left_kw, right_kw
+
+        def energy_at_f(primals, **primals_kw):
+            kw_l, kw_r = split_kwargs(**primals_kw)
+            return self.energy(f(primals, **kw_r), **kw_l)
+
+        def transformation_at_f(primals, **primals_kw):
+            kw_l, kw_r = split_kwargs(**primals_kw)
+            return self.transformation(f(primals, **kw_r), **kw_l)
+
+        def metric_at_f(primals, tangents, **primals_kw):
+            kw_l, kw_r = split_kwargs(**primals_kw)
+            # Note, judging by a simple benchmark on a large problem,
+            # transposing the JVP seems faster than computing the VJP again. On
+            # small problems there seems to be no measurable difference.
+            y, fwd = linearize(Partial(f, **kw_r), primals)
+            bwd = linear_transpose(fwd, primals)
+            return bwd(self.metric(y, fwd(tangents), **kw_l))[0]
+
+        def left_sqrt_metric_at_f(primals, tangents, **primals_kw):
+            kw_l, kw_r = split_kwargs(**primals_kw)
+            y, bwd = vjp(Partial(f, **kw_r), primals)
+            left_at_fp = self.left_sqrt_metric(y, tangents, **kw_l)
+            return bwd(left_at_fp)[0]
+
+        return self.new(
+            energy_at_f,
+            transformation=transformation_at_f,
+            left_sqrt_metric=left_sqrt_metric_at_f,
+            metric=metric_at_f
+        )
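+
+    # A hedged composition sketch (assuming `forward` maps latent parameters
+    # to the data domain): `lh @ forward` applies `forward` first, i.e.
+    # `(lh @ forward).energy(p) == lh.energy(forward(p))`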
+
+    def __add__(self, other):
+        if not isinstance(other, Likelihood):
+            te = (
+                "object which to add to this instance is of invalid type"
+                f" {type(other)!r}"
+            )
+            raise TypeError(te)
+
+        def joined_hamiltonian(p, **pkw):
+            return self.energy(p, **pkw) + other.energy(p, **pkw)
+
+        def joined_metric(p, t, **pkw):
+            return self.metric(p, t, **pkw) + other.metric(p, t, **pkw)
+
+        joined_tangents_shape = {
+            "lh_left": self._lsm_tan_shp,
+            "lh_right": other._lsm_tan_shp
+        }
+
+        def joined_transformation(p, **pkw):
+            from warnings import warn
+
+            # FIXME
+            warn("adding transformations is untested", UserWarning)
+            return {
+                "lh_left": self.transformation(p, **pkw),
+                "lh_right": other.transformation(p, **pkw)
+            }
+
+        def joined_left_sqrt_metric(p, t, **pkw):
+            return self.left_sqrt_metric(
+                p, t["lh_left"], **pkw
+            ) + other.left_sqrt_metric(p, t["lh_right"], **pkw)
+
+        return Likelihood(
+            joined_hamiltonian,
+            transformation=joined_transformation,
+            left_sqrt_metric=joined_left_sqrt_metric,
+            metric=joined_metric,
+            lsm_tangents_shape=joined_tangents_shape
+        )
+
+
+class StandardHamiltonian():
+    """Joined object storage composed of a user-defined likelihood and a
+    standard normal likelihood as prior.
+    """
+    def __init__(
+        self,
+        likelihood: Likelihood,
+        _compile_joined: bool = False,
+        _compile_kwargs: dict = {}
+    ):
+        """Instantiates a new standardized Hamiltonian, i.e. a likelihood
+        joined with a standard normal prior.
+
+        Parameters
+        ----------
+        likelihood : Likelihood
+            Energy, left-square-root of metric and metric of the likelihood.
+        """
+        self._lh = likelihood
+
+        def joined_hamiltonian(primals, **primals_kw):
+            # Assume the first primals to be the parameters
+            return self._lh(
+                primals, **primals_kw
+            ) + 0.5 * sum_of_squares(primals)
+
+        def joined_metric(primals, tangents, **primals_kw):
+            return self._lh.metric(primals, tangents, **primals_kw) + tangents
+
+        if _compile_joined:
+            from jax import jit
+            joined_hamiltonian = jit(joined_hamiltonian, **_compile_kwargs)
+            joined_metric = jit(joined_metric, **_compile_kwargs)
+        self._hamiltonian = joined_hamiltonian
+        self._metric = joined_metric
+
+    @doc_from(Likelihood.__call__)
+    def __call__(self, primals, **primals_kw):
+        return self.energy(primals, **primals_kw)
+
+    @doc_from(Likelihood.energy)
+    def energy(self, primals, **primals_kw):
+        return self._hamiltonian(primals, **primals_kw)
+
+    @doc_from(Likelihood.metric)
+    def metric(self, primals, tangents, **primals_kw):
+        return self._metric(primals, tangents, **primals_kw)
+
+    @property
+    def likelihood(self):
+        return self._lh
+
+    def jit(self, **kwargs):
+        return StandardHamiltonian(
+            self.likelihood.jit(**kwargs),
+            _compile_joined=True,
+            _compile_kwargs=kwargs
+        )
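+
+
+# A hedged end-to-end sketch (all names are assumptions, not part of this
+# module's API):
+#
+#   lh = Likelihood(
+#       lambda p: 0.5 * sum_of_squares(p - 1.),  # negative log-likelihood
+#       lsm_tangents_shape=ShapeWithDtype((3, ))
+#   )
+#   ham = StandardHamiltonian(lh)
+#   # ham(p) == lh(p) + 0.5 * sum_of_squares(p)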
diff --git a/src/re/optimize.py b/src/re/optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..6699cf5da6c4df8cd42775bb628659466f2313af
--- /dev/null
+++ b/src/re/optimize.py
@@ -0,0 +1,481 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import sys
+from datetime import datetime
+from typing import Any, Callable, Dict, Mapping, NamedTuple, Optional, Tuple, Union
+
+from jax import lax
+from jax import numpy as jnp
+from jax.tree_util import Partial
+
+from . import conjugate_gradient
+from .forest_util import assert_arithmetics, common_type, size, where
+from .forest_util import norm as jft_norm
+from .sugar import sum_of_squares
+
+
+class OptimizeResults(NamedTuple):
+    """Object holding optimization results inspired by JAX and scipy.
+
+    Attributes
+    ----------
+    x : ndarray
+        The solution of the optimization.
+    success : bool
+        Whether or not the optimizer exited successfully.
+    status : int
+        Termination status of the optimizer. Its value depends on the
+        underlying solver. NOTE: in contrast to SciPy there is no `message`
+        with details since strings are not statically memory bound in JAX.
+    fun, jac, hess: ndarray
+        Values of objective function, its Jacobian and its Hessian (if
+        available). The Hessians may be approximations, see the documentation
+        of the function in question.
+    hess_inv : object
+        Inverse of the objective function's Hessian; may be an approximation.
+        Not available for all solvers.
+    nfev, njev, nhev : int
+        Number of evaluations of the objective functions and of its
+        Jacobian and Hessian.
+    nit : int
+        Number of iterations performed by the optimizer.
+    """
+    x: Any
+    success: Union[bool, jnp.ndarray]
+    status: Union[int, jnp.ndarray]
+    fun: Any
+    jac: Any
+    hess: Optional[jnp.ndarray] = None
+    hess_inv: Optional[jnp.ndarray] = None
+    nfev: Union[None, int, jnp.ndarray] = None
+    njev: Union[None, int, jnp.ndarray] = None
+    nhev: Union[None, int, jnp.ndarray] = None
+    nit: Union[None, int, jnp.ndarray] = None
+    # Trust-Region specific slots
+    trust_radius: Union[None, float, jnp.ndarray] = None
+    jac_magnitude: Union[None, float, jnp.ndarray] = None
+    good_approximation: Union[None, bool, jnp.ndarray] = None
+
+
+def _prepare_vag_hessp(fun, jac, hessp,
+                       fun_and_grad) -> Tuple[Callable, Callable]:
+    """Returns a tuple of functions for computing the value-and-gradient and
+    the Hessian-Vector-Product.
+    """
+    from warnings import warn
+
+    if fun_and_grad is None:
+        if fun is not None and jac is not None:
+            uw = "computing the function together with its gradient would be faster"
+            warn(uw, UserWarning)
+
+            def fun_and_grad(x):
+                return (fun(x), jac(x))
+        elif fun is not None:
+            from jax import value_and_grad
+
+            fun_and_grad = value_and_grad(fun)
+        else:
+            ValueError("no function specified")
+
+    if hessp is None:
+        from jax import grad, jvp
+
+        jac = grad(fun) if jac is None else jac
+
+        def hessp(primals, tangents):
+            return jvp(jac, (primals, ), (tangents, ))[1]
+
+    return fun_and_grad, hessp
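+
+
+# The fallback Hessian-vector product above is forward-over-reverse:
+# `hessp(p, t) = jvp(grad(fun), (p, ), (t, ))[1]` evaluates `H(p) @ t`
+# without ever materializing the Hessian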
+
+
+def newton_cg(fun=None, x0=None, *args, **kwargs):
+    """Minimize a scalar-valued function using the Newton-CG algorithm."""
+    if x0 is not None:
+        assert_arithmetics(x0)
+    return _newton_cg(fun, x0, *args, **kwargs).x
+
+
+def _newton_cg(
+    fun=None,
+    x0=None,
+    *,
+    miniter=None,
+    maxiter=None,
+    energy_reduction_factor=0.1,
+    old_fval=None,
+    absdelta=None,
+    norm_ord=None,
+    xtol=1e-5,
+    jac: Optional[Callable] = None,
+    fun_and_grad=None,
+    hessp=None,
+    cg=conjugate_gradient._cg,
+    name=None,
+    time_threshold=None,
+    cg_kwargs=None
+):
+    norm_ord = 1 if norm_ord is None else norm_ord
+    miniter = 0 if miniter is None else miniter
+    maxiter = 200 if maxiter is None else maxiter
+    xtol = xtol * size(x0)
+
+    pos = x0
+    fun_and_grad, hessp = _prepare_vag_hessp(
+        fun, jac, hessp, fun_and_grad=fun_and_grad
+    )
+    cg_kwargs = {} if cg_kwargs is None else cg_kwargs
+    cg_name = name + "CG" if name is not None else None
+
+    energy, g = fun_and_grad(pos)
+    nfev, njev, nhev = 1, 1, 0
+    if jnp.isnan(energy):
+        raise ValueError("energy is Nan")
+    status = -1
+    i = 0
+    for i in range(1, maxiter + 1):
+        # Newton approximates the potential up to second order. The CG energy
+        # (`0.5 * x.T @ A @ x - x.T @ b`) and the approximation to the true
+        # potential in Newton thus live on comparable energy scales. Hence, the
+        # energy in a Newton minimization can be used to set the CG energy
+        # convergence criterion.
+        if old_fval and energy_reduction_factor:
+            cg_absdelta = energy_reduction_factor * (old_fval - energy)
+        else:
+            cg_absdelta = None if absdelta is None else absdelta / 100.
+        mag_g = jft_norm(g, ord=cg_kwargs.get("norm_ord", 1), ravel=True)
+        cg_resnorm = jnp.minimum(
+            0.5, jnp.sqrt(mag_g)
+        ) * mag_g  # taken from SciPy
+        default_kwargs = {
+            "absdelta": cg_absdelta,
+            "resnorm": cg_resnorm,
+            "norm_ord": 1,
+            "_within_newton": True,  # handle non-pos-def
+            "name": cg_name,
+            "time_threshold": time_threshold
+        }
+        cg_res = cg(Partial(hessp, pos), g, **{**default_kwargs, **cg_kwargs})
+        nat_g, info = cg_res.x, cg_res.info
+        nhev += cg_res.nfev
+        if info is not None and info < 0:
+            raise ValueError("conjugate gradient failed")
+
+        naive_ls_it = 0
+        dd = nat_g  # negative descent direction
+        grad_scaling = 1.
+        ls_reset = False
+        for naive_ls_it in range(9):
+            new_pos = pos - grad_scaling * dd
+            new_energy, new_g = fun_and_grad(new_pos)
+            nfev, njev = nfev + 1, njev + 1
+            if new_energy <= energy:
+                break
+
+            grad_scaling /= 2
+            if naive_ls_it == 5:
+                ls_reset = True
+                gam = float(sum_of_squares(g))
+                curv = float(g.dot(hessp(pos, g)))
+                nhev += 1
+                grad_scaling = 1.
+                dd = gam / curv * g
+        else:
+            grad_scaling = 0.
+            nm = "N" if name is None else name
+            msg = f"{nm}: WARNING: Energy would increase; aborting"
+            print(msg, file=sys.stderr)
+            status = -1
+            break
+
+        energy_diff = energy - new_energy
+        old_fval = energy
+        energy = new_energy
+        pos = new_pos
+        g = new_g
+
+        descent_norm = grad_scaling * jft_norm(dd, ord=norm_ord, ravel=True)
+        if name is not None:
+            msg = (
+                f"{name}: →:{grad_scaling} ↺:{ls_reset} #∇²:{nhev:02d}"
+                f" |↘|:{descent_norm:.6e} 🞋:{xtol:.6e}"
+                f"\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}"
+                + (f" 🞋:{absdelta:.6e}" if absdelta is not None else "")
+            )
+            print(msg, file=sys.stderr)
+        if jnp.isnan(new_energy):
+            raise ValueError("energy is NaN")
+        min_cond = naive_ls_it < 2 and i > miniter
+        if absdelta is not None and 0. <= energy_diff < absdelta and min_cond:
+            status = 0
+            break
+        if descent_norm <= xtol and i > miniter:
+            status = 0
+            break
+        if time_threshold is not None and datetime.now() > time_threshold:
+            status = i
+            break
+    else:
+        status = i
+        nm = "N" if name is None else name
+        print(f"{nm}: Iteration Limit Reached", file=sys.stderr)
+    return OptimizeResults(
+        x=pos,
+        success=True,
+        status=status,
+        fun=energy,
+        jac=g,
+        nit=i,
+        nfev=nfev,
+        njev=njev,
+        nhev=nhev
+    )
+
+
+class _TrustRegionState(NamedTuple):
+    x: Any
+    converged: Union[bool, jnp.ndarray]
+    status: Union[int, jnp.ndarray]
+    fun: Any
+    jac: Any
+    nfev: Union[int, jnp.ndarray]
+    njev: Union[int, jnp.ndarray]
+    nhev: Union[int, jnp.ndarray]
+    nit: Union[int, jnp.ndarray]
+    trust_radius: Union[float, jnp.ndarray]
+    jac_magnitude: Union[float, jnp.ndarray]
+    old_fval: Union[float, jnp.ndarray]
+
+
+def _trust_ncg(
+    fun=None,
+    x0=None,
+    *,
+    maxiter: Optional[int] = None,
+    energy_reduction_factor=0.1,
+    old_fval=jnp.nan,
+    absdelta=None,
+    gtol: float = 1e-4,
+    max_trust_radius: Union[float, jnp.ndarray] = 1000.,
+    initial_trust_radius: Union[float, jnp.ndarray] = 1.0,
+    eta: Union[float, jnp.ndarray] = 0.15,
+    subproblem=conjugate_gradient._cg_steihaug_subproblem,
+    jac: Optional[Callable] = None,
+    hessp: Optional[Callable] = None,
+    fun_and_grad: Optional[Callable] = None,
+    subproblem_kwargs: Optional[Dict[str, Any]] = None,
+    name: Optional[str] = None
+) -> OptimizeResults:
+    maxiter = 200 if maxiter is None else maxiter
+
+    status = jnp.where(maxiter == 0, 1, 0)
+
+    if not (0 <= eta < 0.25):
+        raise ValueError("invalid acceptance stringency")
+    # ValueError("gradient tolerance must be positive")
+    status = jnp.where(gtol < 0., -1, status)
+    # ValueError("max trust radius must be positive")
+    status = jnp.where(max_trust_radius <= 0, -1, status)
+    # ValueError("initial trust radius must be positive")
+    status = jnp.where(initial_trust_radius <= 0, -1, status)
+    # ValueError("initial trust radius must be less than the max trust radius")
+    status = jnp.where(initial_trust_radius >= max_trust_radius, -1, status)
+
+    common_dtp = common_type(x0)
+    # Inspired by SciPy's Newton-CG minimizer
+    eps = 6. * jnp.finfo(common_dtp).eps
+
+    fun_and_grad, hessp = _prepare_vag_hessp(
+        fun, jac, hessp, fun_and_grad=fun_and_grad
+    )
+    subproblem_kwargs = {} if subproblem_kwargs is None else subproblem_kwargs
+    cg_name = name + "SP" if name is not None else None
+
+    f_0, g_0 = fun_and_grad(x0)
+    g_0_mag = jft_norm(
+        g_0, ord=subproblem_kwargs.get("norm_ord", 1), ravel=True
+    )
+    status = jnp.where(jnp.isfinite(g_0_mag), status, 2)
+    init_params = _TrustRegionState(
+        converged=False,
+        status=status,
+        nit=0,
+        x=x0,
+        fun=f_0,
+        jac=g_0,
+        jac_magnitude=g_0_mag,
+        nfev=1,
+        njev=1,
+        nhev=0,
+        trust_radius=initial_trust_radius,
+        old_fval=old_fval
+    )
+
+    def _trust_region_body_f(params: _TrustRegionState) -> _TrustRegionState:
+        x_k, g_k, g_k_mag = params.x, params.jac, params.jac_magnitude
+        i, f_k, old_fval = params.nit, params.fun, params.old_fval
+        tr = params.trust_radius
+
+        i += 1
+
+        if energy_reduction_factor:
+            cg_absdelta = energy_reduction_factor * (old_fval - f_k)
+        else:
+            cg_absdelta = None if absdelta is None else absdelta / 100.
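+        # Scale the CG residual tolerance with the gradient magnitude in the
+        # spirit of inexact Newton methods: early subproblems are solved
+        # loosely, later ones increasingly tightly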
+        cg_resnorm = jnp.minimum(0.5, jnp.sqrt(g_k_mag)) * g_k_mag
+        # TODO: add an internal success check for future subproblem approaches
+        # that might not be solvable
+        default_kwargs = {
+            "absdelta": cg_absdelta,
+            "resnorm": cg_resnorm,
+            "trust_radius": tr,
+            "norm_ord": 1,
+            "name": cg_name
+        }
+        sub_result = subproblem(
+            f_k, g_k, Partial(hessp, x_k),
+            **{**default_kwargs, **subproblem_kwargs}
+        )
+
+        pred_f_kp1 = sub_result.pred_f
+        x_kp1 = x_k + sub_result.step
+        f_kp1, g_kp1 = fun_and_grad(x_kp1)
+
+        actual_reduction = f_k - f_kp1
+        pred_reduction = f_k - pred_f_kp1
+
+        # update the trust radius according to the actual/predicted ratio
+        rho = actual_reduction / pred_reduction
+        tr_kp1 = jnp.where(rho < 0.25, tr * 0.25, tr)
+        tr_kp1 = jnp.where(
+            (rho > 0.75) & sub_result.hits_boundary,
+            jnp.minimum(2. * tr, max_trust_radius), tr_kp1
+        )
+
+        # compute norm to check for convergence
+        g_kp1_mag = jft_norm(
+            g_kp1, ord=subproblem_kwargs.get("norm_ord", 1), ravel=True
+        )
+
+        # if the ratio is high enough then accept the proposed step
+        f_kp1, x_kp1, g_kp1, g_kp1_mag = where(
+            rho > eta, (f_kp1, x_kp1, g_kp1, g_kp1_mag),
+            (f_k, x_k, g_k, g_k_mag)
+        )
+
+        # Check whether we arrived at the float precision
+        energy_eps = eps * jnp.abs(f_kp1)
+        converged = (actual_reduction <=
+                     energy_eps) & (actual_reduction > -energy_eps)
+
+        converged |= g_kp1_mag < gtol
+        if absdelta is not None:
+            converged |= (rho > eta) & (actual_reduction >
+                                        0.) & (actual_reduction < absdelta)
+
+        status = jnp.where(converged, 0, params.status)
+        status = jnp.where(i >= maxiter, 1, status)
+        status = jnp.where(pred_reduction <= 0, 2, status)
+        params = _TrustRegionState(
+            converged=converged,
+            nit=i,
+            x=x_kp1,
+            fun=f_kp1,
+            jac=g_kp1,
+            jac_magnitude=g_kp1_mag,
+            nfev=params.nfev + sub_result.nfev + 1,
+            njev=params.njev + sub_result.njev + 1,
+            nhev=params.nhev + sub_result.nhev,
+            trust_radius=tr_kp1,
+            status=status,
+            old_fval=f_k
+        )
+        if name is not None:
+            from jax.experimental.host_callback import call
+
+            def pp(arg):
+                i = arg["i"]
+                msg = (
+                    "{name}: ↗:{tr:.6e} ⬤:{hit} ∝:{rho:.2e} #∇²:{nhev:02d}"
+                    "\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}"
+                    + (" 🞋:{absdelta:.6e}" if absdelta is not None else "") + (
+                        "\n{name}: Iteration Limit Reached"
+                        if i == maxiter else ""
+                    )
+                )
+                print(msg.format(name=name, **arg), file=sys.stderr)
+
+            printable_state = {
+                "i": i,
+                "energy": params.fun,
+                "energy_diff": actual_reduction,
+                "maxiter": maxiter,
+                "absdelta": absdelta,
+                "tr": params.trust_radius,
+                "rho": rho,
+                "nhev": params.nhev,
+                "hit": sub_result.hits_boundary
+            }
+            call(pp, printable_state, result_shape=None)
+        return params
+
+    def _trust_region_cond_f(params: _TrustRegionState) -> bool:
+        return jnp.logical_not(params.converged) & (params.status == 0)
+
+    state = lax.while_loop(
+        _trust_region_cond_f, _trust_region_body_f, init_params
+    )
+
+    return OptimizeResults(
+        success=state.converged & (state.status == 0),
+        nit=state.nit,
+        x=state.x,
+        fun=state.fun,
+        jac=state.jac,
+        nfev=state.nfev,
+        njev=state.njev,
+        nhev=state.nhev,
+        jac_magnitude=state.jac_magnitude,
+        trust_radius=state.trust_radius,
+        status=state.status
+    )
+
+
+def trust_ncg(fun=None, x0=None, *args, **kwargs):
+    if x0 is not None:
+        assert_arithmetics(x0)
+    return _trust_ncg(fun, x0, *args, **kwargs).x
+
+
+def minimize(
+    fun: Optional[Callable[..., float]],
+    x0,
+    args: Tuple = (),
+    *,
+    method: str,
+    tol: Optional[float] = None,
+    options: Optional[Mapping[str, Any]] = None
+) -> OptimizeResults:
+    """Minimize fun."""
+    assert_arithmetics(x0)
+    if options is None:
+        options = {}
+    if not isinstance(args, tuple):
+        te = f"args argument must be a tuple, got {type(args)!r}"
+        raise TypeError(te)
+
+    fun_with_args = fun
+    if args:
+        fun_with_args = lambda x: fun(x, *args)
+
+    if tol is not None:
+        raise ValueError("use solver-specific options")
+
+    if method.lower() in ('newton-cg', 'newtoncg', 'ncg'):
+        return _newton_cg(fun_with_args, x0, **options)
+    elif method.lower() in ('trust-ncg', 'trustncg'):
+        return _trust_ncg(fun_with_args, x0, **options)
+
+    raise ValueError(f"method {method} not recognized")
diff --git a/src/re/refine.py b/src/re/refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..b95572e5481f17e8ec05ef37ccc998b62b6fb923
--- /dev/null
+++ b/src/re/refine.py
@@ -0,0 +1,511 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from math import ceil
+from string import ascii_uppercase
+from typing import Callable, Literal, Optional, Union
+
+from jax import vmap
+from jax import numpy as jnp
+from jax.lax import conv_general_dilated, dynamic_slice, fori_loop
+import numpy as np
+
+NDARRAY = Union[jnp.ndarray, np.ndarray]
+# N - batch dimension
+# C - feature dimension of data (channel)
+# I - input dimension of kernel
+# O - output dimension of kernel
+CONV_DIMENSION_NAMES = "".join(el for el in ascii_uppercase if el not in "NCIO")
+
+
+def _assert(assertion):
+    if not assertion:
+        raise AssertionError()
+
+
+def _get_cov_from_loc(kernel=None,
+                      cov_from_loc=None
+                     ) -> Callable[[NDARRAY, NDARRAY], NDARRAY]:
+    if cov_from_loc is None and callable(kernel):
+
+        def cov_from_loc_sngl(x, y):
+            return kernel(jnp.linalg.norm(x - y))
+
+        cov_from_loc = vmap(
+            vmap(cov_from_loc_sngl, in_axes=(None, 0)), in_axes=(0, None)
+        )
+    elif not callable(cov_from_loc):
+        ve = "exactly one of `cov_from_loc` or `kernel` must be set and callable"
+        raise ValueError(ve)
+    # TODO: benchmark whether using `triu_indices(n, k=1)` and
+    # `diag_indices(n)` is advantageous
+    return cov_from_loc
+
+
+def layer_refinement_matrices(
+    distances,
+    kernel: Optional[Callable] = None,
+    cov_from_loc: Optional[Callable] = None,
+    *,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+    _with_zeros: bool = False,
+):
+    cov_from_loc = _get_cov_from_loc(kernel, cov_from_loc)
+    distances = jnp.asarray(distances)
+    # TODO: distances must be a tensor iff _coarse_size > 3
+    # TODO: allow different grid sizes for different axis
+    csz = int(_coarse_size)  # coarse size
+    if _coarse_size % 2 != 1:
+        raise ValueError("only odd numbers allowed for `_coarse_size`")
+    fsz = int(_fine_size)  # fine size
+    if _fine_size % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+
+    ndim = distances.size
+    csz_half = int((csz - 1) / 2)
+    gc = jnp.arange(-csz_half, csz_half + 1, dtype=float)
+    gc = distances.reshape(ndim, 1) * gc
+    gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1)
+    if _fine_strategy == "jump":
+        gf = jnp.arange(fsz, dtype=float) / fsz - 0.5 + 0.5 / fsz
+    elif _fine_strategy == "extend":
+        gf = jnp.arange(fsz, dtype=float) / 2 - 0.25 * (fsz - 1)
+    else:
+        raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
+    gf = distances.reshape(ndim, 1) * gf
+    gf = jnp.stack(jnp.meshgrid(*gf, indexing="ij"), axis=-1)
+    # On the GPU a single `cov_from_loc` call is about twice as fast as three
+    # separate calls for coarse-coarse, fine-fine and coarse-fine.
+    coord = jnp.concatenate(
+        (gc.reshape(-1, ndim), gf.reshape(-1, ndim)), axis=0
+    )
+    cov = cov_from_loc(coord, coord)
+    cov_ff = cov[-fsz**ndim:, -fsz**ndim:]
+    cov_fc = cov[-fsz**ndim:, :-fsz**ndim]
+    cov_cc = cov[:-fsz**ndim, :-fsz**ndim]
+    cov_cc_inv = jnp.linalg.inv(cov_cc)
+
+    olf = cov_fc @ cov_cc_inv
+    # Also see Schur-Complement
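+    # `olf` is the conditional mean operator of the fine pixels given the
+    # coarse pixels for the joint Gaussian defined by `cov`; `fine_kernel`
+    # below is the corresponding conditional covariance, i.e. the Schur
+    # complement cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T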
+    if _with_zeros:
+        r = jnp.linalg.norm(gc.reshape(-1, ndim), axis=1)
+        r_cutoff = jnp.max(distances) * csz_half
+        # dampening is chosen somewhat arbitrarily
+        r_dampening = jnp.max(distances)**-ndim
+        olf_wgt_sphere = jnp.where(
+            r <= r_cutoff, 1.,
+            jnp.exp(-r_dampening * jnp.abs(r - r_cutoff)**ndim)
+        )
+        olf *= olf_wgt_sphere[jnp.newaxis, ...]
+        fine_kernel = cov_ff - olf @ cov_cc @ olf.T
+    else:
+        fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T
+    # Implicitly assume a white power spectrum beyond the numerics limit. Use
+    # the diagonal as an estimate for the magnitude of the variance.
+    fine_kernel_fallback = jnp.diag(jnp.abs(jnp.diag(fine_kernel)))
+    # Never produce NaNs (https://github.com/google/jax/issues/1052)
+    fine_kernel = jnp.where(
+        jnp.all(jnp.diag(fine_kernel) > 0.), fine_kernel, fine_kernel_fallback
+    )
+    fine_kernel_sqrt = jnp.linalg.cholesky(fine_kernel)
+
+    return olf, fine_kernel_sqrt
+
+
+def refinement_matrices(
+    shape0,
+    depth,
+    distances0,
+    kernel: Optional[Callable] = None,
+    cov_from_loc: Optional[Callable] = None,
+    *,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+    **kwargs,
+):
+    cov_from_loc = _get_cov_from_loc(kernel, cov_from_loc)
+
+    shape0 = np.atleast_1d(shape0)
+    distances0 = jnp.atleast_1d(distances0)
+    if shape0.shape != distances0.shape:
+        ve = (
+            f"shape of `shape0` {shape0.shape} is incompatible with"
+            f" shape of `distances0` {distances0.shape}"
+        )
+        raise ValueError(ve)
+    c0 = [d * jnp.arange(sz, dtype=float) for d, sz in zip(distances0, shape0)]
+    coord0 = jnp.stack(jnp.meshgrid(*c0, indexing="ij"), axis=-1)
+    coord0 = coord0.reshape(-1, len(shape0))
+    cov_sqrt0 = jnp.linalg.cholesky(cov_from_loc(coord0, coord0))
+
+    if _fine_strategy == "jump":
+        dist_by_depth = distances0 / _fine_size**jnp.arange(0, depth
+                                                           ).reshape(-1, 1)
+    elif _fine_strategy == "extend":
+        dist_by_depth = distances0 / 2**jnp.arange(0, depth).reshape(-1, 1)
+    else:
+        raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
+    olaf = partial(
+        layer_refinement_matrices,
+        cov_from_loc=cov_from_loc,
+        _coarse_size=_coarse_size,
+        _fine_size=_fine_size,
+        _fine_strategy=_fine_strategy,
+        **kwargs
+    )
+    opt_lin_filter, kernel_sqrt = vmap(olaf, in_axes=0,
+                                       out_axes=(0, 0))(dist_by_depth)
+    return opt_lin_filter, (cov_sqrt0, kernel_sqrt)
+
+
+def _vmap_squeeze_first(fun, *args, **kwargs):
+    vfun = vmap(fun, *args, **kwargs)
+
+    def vfun_apply(*x):
+        return vfun(jnp.squeeze(x[0], axis=0), *x[1:])
+
+    return vfun_apply
+
+
+def refine_conv_general(
+    coarse_values,
+    excitations,
+    olf,
+    fine_kernel_sqrt,
+    precision=None,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    ndim = np.ndim(coarse_values)
+    # Introduce an artificial channel dimension for the matrix product
+    # TODO: allow different grid sizes for different axis
+    csz = int(_coarse_size)  # coarse size
+    if _coarse_size % 2 != 1:
+        raise ValueError("only odd numbers allowed for `_coarse_size`")
+    fsz = int(_fine_size)  # fine size
+    if _fine_size % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+    if olf.shape[:-2] != fine_kernel_sqrt.shape[:-2]:
+        ve = (
+            "incompatible optimal linear filter (`olf`) and `fine_kernel_sqrt` shapes"
+            f"; got {olf.shape} and {fine_kernel_sqrt.shape}"
+        )
+        raise ValueError(ve)
+    if olf.ndim > 2:
+        irreg_shape = olf.shape[:-2]
+    elif olf.ndim == 2:
+        irreg_shape = (1, ) * ndim
+    else:
+        ve = f"invalid shape of optimal linear filter (`olf`); got {olf.shape}"
+        raise ValueError(ve)
+    olf = olf.reshape(
+        irreg_shape + (fsz**ndim, ) + (csz, ) * (ndim - 1) + (1, csz)
+    )
+    fine_kernel_sqrt = fine_kernel_sqrt.reshape(irreg_shape + (fsz**ndim, ) * 2)
+
+    if _fine_strategy == "jump":
+        window_strides = (1, ) * ndim
+        fine_init_shape = tuple(n - (csz - 1)
+                                for n in coarse_values.shape) + (fsz**ndim, )
+        fine_final_shape = tuple(
+            fsz * (n - (csz - 1)) for n in coarse_values.shape
+        )
+        convolution_slices = list(range(csz))
+    elif _fine_strategy == "extend":
+        window_strides = (fsz // 2, ) * ndim
+        fine_init_shape = tuple(
+            ceil((n - (csz - 1)) / (fsz // 2)) for n in coarse_values.shape
+        ) + (fsz**ndim, )
+        fine_final_shape = tuple(
+            fsz * ceil((n - (csz - 1)) / (fsz // 2))
+            for n in coarse_values.shape
+        )
+        convolution_slices = list(range(0, csz * fsz // 2, fsz // 2))
+
+        if fsz // 2 > csz:
+            ve = "extrapolation is not allowed (use `fine_size / 2 <= coarse_size`)"
+            raise ValueError(ve)
+    else:
+        raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
+
+    if ndim > len(CONV_DIMENSION_NAMES):
+        ve = f"convolution for {ndim} dimensions not yet implemented"
+        raise ValueError(ve)
+    dim_names = CONV_DIMENSION_NAMES[:ndim]
+    conv = partial(
+        conv_general_dilated,
+        window_strides=window_strides,
+        padding="valid",
+        # channel-last layout is most efficient for vision models (at least in
+        # PyTorch)
+        dimension_numbers=(
+            f"N{dim_names}C", f"O{dim_names}I", f"N{dim_names}C"
+        ),
+        precision=precision,
+    )
+
+    c_shp_n1 = coarse_values.shape[-1]
+    c_slc_shp = (1, )
+    c_slc_shp += tuple(
+        c if i == 1 else csz
+        for i, c in zip(irreg_shape, coarse_values.shape[:-1])
+    )
+    c_slc_shp += (-1, csz)
+
+    fine = jnp.zeros(fine_init_shape)
+    PLC = -1 << 31  # integer placeholder outside of the here encountered regimes
+    irreg_indices = jnp.stack(
+        jnp.meshgrid(
+            *[
+                jnp.arange(sz) if sz != 1 else jnp.array([PLC])
+                for sz in irreg_shape
+            ],
+            indexing="ij"
+        ),
+        axis=-1
+    )
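+    # Axes with a single filter entry (the regular axes) carry the
+    # placeholder index; `single_refinement_step` below substitutes a full
+    # slice for it, broadcasting the same filter along those axes.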
+
+    def single_refinement_step(i, fine: jnp.ndarray) -> jnp.ndarray:
+        irreg_idx = jnp.unravel_index(i, irreg_indices.shape[:-1])
+        _assert(
+            len(irreg_shape) == len(irreg_indices[irreg_idx]) ==
+            len(window_strides)
+        )
+        fine_init_idx = tuple(
+            idx if sz != 1 else slice(None)
+            for sz, idx in zip(irreg_shape, irreg_indices[irreg_idx])
+        )
+        # Make JAX/XLA happy with `dynamic_slice`
+        coarse_idx = tuple(
+            (ws * idx, csz) if sz != 1 else (0, cend)
+            for ws, sz, idx, cend in zip(
+                window_strides, irreg_shape, irreg_indices[irreg_idx],
+                coarse_values.shape
+            )
+        )
+        coarse_idx_select = partial(
+            dynamic_slice,
+            start_indices=list(zip(*coarse_idx))[0],
+            slice_sizes=list(zip(*coarse_idx))[1]
+        )
+
+        olf_at_i = jnp.squeeze(
+            olf[fine_init_idx],
+            axis=tuple(range(sum(i == 1 for i in irreg_shape)))
+        )
+        if irreg_shape[-1] == 1 and fine_init_shape[-1] != 1:
+            _assert(fine_init_idx[-1] == slice(None))
+            # loop over conv channel offsets to apply the filter matrix in a convolution
+            for i_f, i_c in enumerate(convolution_slices):
+                c = conv(
+                    coarse_idx_select(coarse_values)[..., i_c:c_shp_n1 -
+                                                     (c_shp_n1 - i_c) %
+                                                     csz].reshape(c_slc_shp),
+                    olf_at_i
+                )[0]
+                c = jnp.squeeze(
+                    c,
+                    axis=tuple(a for a, i in enumerate(irreg_shape) if i != 1)
+                )
+                toti = fine_init_idx[:-1] + (slice(i_f, None, csz), )
+                fine = fine.at[toti].set(c)
+        else:
+            _assert(
+                not isinstance(fine_init_idx[-1], slice) and
+                fine_init_idx[-1].ndim == 0
+            )
+            c = conv(
+                coarse_idx_select(coarse_values).reshape(c_slc_shp), olf_at_i
+            )[0]
+            c = jnp.squeeze(
+                c, axis=tuple(a for a, i in enumerate(irreg_shape) if i != 1)
+            )
+            fine = fine.at[fine_init_idx].set(c)
+
+        return fine
+
+    fine = fori_loop(
+        0, np.prod(irreg_indices.shape[:-1]), single_refinement_step, fine
+    )
+
+    matmul = partial(jnp.matmul, precision=precision)
+    for i in irreg_shape[::-1]:
+        if i != 1:
+            matmul = vmap(matmul, in_axes=(0, 0))
+        else:
+            matmul = _vmap_squeeze_first(matmul, in_axes=(None, 0))
+    m = matmul(fine_kernel_sqrt, excitations.reshape(fine_init_shape))
+    rm_axs = tuple(
+        ax for ax, i in enumerate(m.shape[len(irreg_shape):], len(irreg_shape))
+        if i == 1
+    )
+    fine += jnp.squeeze(m, axis=rm_axs)
+
+    fine = fine.reshape(fine.shape[:-1] + (fsz, ) * ndim)
+    ax_label = np.arange(2 * ndim)
+    ax_t = [e for els in zip(ax_label[:ndim], ax_label[ndim:]) for e in els]
+    fine = jnp.transpose(fine, axes=ax_t)
+
+    return fine.reshape(fine_final_shape)
+
+
+def refine_slice(
+    coarse_values,
+    excitations,
+    olf,
+    fine_kernel_sqrt,
+    precision=None,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    ndim = np.ndim(coarse_values)
+    csz = int(_coarse_size)  # coarse size
+    if _coarse_size % 2 != 1:
+        raise ValueError("only odd numbers allowed for `_coarse_size`")
+    fsz = int(_fine_size)  # fine size
+    if _fine_size % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+
+    if olf.shape[:-2] != fine_kernel_sqrt.shape[:-2]:
+        ve = (
+            "incompatible optimal linear filter (`olf`) and `fine_kernel_sqrt` shapes"
+            f"; got {olf.shape} and {fine_kernel_sqrt.shape}"
+        )
+        raise ValueError(ve)
+    if olf.ndim > 2:
+        irreg_shape = olf.shape[:-2]
+    elif olf.ndim == 2:
+        irreg_shape = (1, ) * ndim
+    else:
+        ve = f"invalid shape of optimal linear filter (`olf`); got {olf.shape}"
+        raise ValueError(ve)
+    olf = olf.reshape(irreg_shape + (fsz**ndim, ) + (csz, ) * ndim)
+    fine_kernel_sqrt = fine_kernel_sqrt.reshape(irreg_shape + (fsz**ndim, ) * 2)
+
+    if _fine_strategy == "jump":
+        window_strides = (1, ) * ndim
+        fine_init_shape = tuple(n - (csz - 1)
+                                for n in coarse_values.shape) + (fsz**ndim, )
+        fine_final_shape = tuple(
+            fsz * (n - (csz - 1)) for n in coarse_values.shape
+        )
+    elif _fine_strategy == "extend":
+        window_strides = (fsz // 2, ) * ndim
+        fine_init_shape = tuple(
+            ceil((n - (csz - 1)) / (fsz // 2)) for n in coarse_values.shape
+        ) + (fsz**ndim, )
+        fine_final_shape = tuple(
+            fsz * ceil((n - (csz - 1)) / (fsz // 2))
+            for n in coarse_values.shape
+        )
+
+        if fsz // 2 > csz:
+            ve = "extrapolation is not allowed (use `fine_size / 2 <= coarse_size`)"
+            raise ValueError(ve)
+    else:
+        raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
+
+    def matmul_with_window_into(x, y, idx):
+        return jnp.tensordot(
+            x,
+            dynamic_slice(y, idx, slice_sizes=(csz, ) * ndim),
+            axes=ndim,
+            precision=precision
+        )
+
+    filter_coarse = matmul_with_window_into
+    corr_fine = partial(jnp.matmul, precision=precision)
+    for i in irreg_shape[::-1]:
+        if i != 1:
+            filter_coarse = vmap(filter_coarse, in_axes=(0, None, 1))
+            corr_fine = vmap(corr_fine, in_axes=(0, 0))
+        else:
+            filter_coarse = _vmap_squeeze_first(filter_coarse, in_axes=(None, None, 1))
+            corr_fine = _vmap_squeeze_first(corr_fine, in_axes=(None, 0))
+
+    cv_idx = np.mgrid[tuple(
+        slice(None, sz - csz + 1, ws)
+        for sz, ws in zip(coarse_values.shape, window_strides)
+    )]
+    fine = filter_coarse(olf, coarse_values, cv_idx)
+
+    m = corr_fine(fine_kernel_sqrt, excitations.reshape(fine_init_shape))
+    rm_axs = tuple(
+        ax for ax, i in enumerate(m.shape[len(irreg_shape):], len(irreg_shape))
+        if i == 1
+    )
+    fine += jnp.squeeze(m, axis=rm_axs)
+
+    fine = fine.reshape(fine.shape[:-1] + (fsz, ) * ndim)
+    ax_label = np.arange(2 * ndim)
+    ax_t = [e for els in zip(ax_label[:ndim], ax_label[ndim:]) for e in els]
+    fine = jnp.transpose(fine, axes=ax_t)
+
+    return fine.reshape(fine_final_shape)
+
+
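+# NOTE: `refine_conv`, `refine_loop` and `refine_vmap` below operate on
+# one-dimensional `coarse_values` only (`jnp.convolve` is one-dimensional),
+# and `refine_vmap` additionally assumes the default coarse size of 3 and
+# fine size of 2; they seem to serve as simpler reference implementations
+# alongside the general `refine_conv_general` and `refine_slice` above.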
+def refine_conv(
+    coarse_values, excitations, olf, fine_kernel_sqrt, precision=None
+):
+    fine_m = vmap(
+        partial(jnp.convolve, mode="valid", precision=precision),
+        in_axes=(None, 0),
+        out_axes=0
+    )(coarse_values, olf[::-1])
+    fine_m = jnp.moveaxis(fine_m, (0, ), (1, ))
+    fine_std = vmap(jnp.matmul, in_axes=(None, 0))(
+        fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1])
+    )
+
+    return (fine_m + fine_std).ravel()
+
+
+def refine_loop(
+    coarse_values, excitations, olf, fine_kernel_sqrt, precision=None
+):
+    fine_m = [
+        jnp.convolve(coarse_values, o, mode="valid", precision=precision)
+        for o in olf[::-1]
+    ]
+    fine_m = jnp.stack(fine_m, axis=1)
+    fine_std = vmap(jnp.matmul, in_axes=(None, 0))(
+        fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1])
+    )
+
+    return (fine_m + fine_std).ravel()
+
+
+def refine_vmap(
+    coarse_values, excitations, olf, fine_kernel_sqrt, precision=None
+):
+    sh0 = coarse_values.shape[0]
+    conv = vmap(
+        partial(jnp.matmul, precision=precision), in_axes=(None, 0), out_axes=0
+    )
+    fine_m = jnp.zeros((coarse_values.size - 2, 2))
+    fine_m = fine_m.at[0::3].set(
+        conv(olf, coarse_values[:sh0 - sh0 % 3].reshape(-1, 3))
+    )
+    fine_m = fine_m.at[1::3].set(
+        conv(olf, coarse_values[1:sh0 - (sh0 - 1) % 3].reshape(-1, 3))
+    )
+    fine_m = fine_m.at[2::3].set(
+        conv(olf, coarse_values[2:sh0 - (sh0 - 2) % 3].reshape(-1, 3))
+    )
+
+    fine_std = vmap(jnp.matmul, in_axes=(None, 0))(
+        fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1])
+    )
+
+    return (fine_m + fine_std).ravel()
+
+
+refine = refine_slice
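+
+
+# A minimal end-to-end sketch of how these helpers compose (illustrative; the
+# exponential kernel and all shapes are arbitrary, and `xi` is assumed to
+# hold standard-normal excitations with one block per refinement level, cf.
+# `get_refinement_shapewithdtype` in refine_util.py):
+#
+#   kernel = lambda r: jnp.exp(-r)
+#   olfs, (cov_sqrt0, ks) = refinement_matrices(
+#       shape0=(12, ), depth=2, distances0=(1., ), kernel=kernel
+#   )
+#   values = (cov_sqrt0 @ xi[0].ravel()).reshape(xi[0].shape)
+#   for x, olf, k in zip(xi[1:], olfs, ks):
+#       values = refine(values, x, olf, k)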
diff --git a/src/re/refine_chart.py b/src/re/refine_chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fbe7f6a4fb879664514482d68cbc959771b485
--- /dev/null
+++ b/src/re/refine_chart.py
@@ -0,0 +1,916 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from collections import namedtuple
+from functools import partial
+from typing import Callable, Iterable, Literal, Optional, Tuple, Union
+
+from jax import numpy as jnp
+from jax import vmap
+import numpy as np
+
+from .refine import _get_cov_from_loc, refine
+from .refine_util import (
+    coarse2fine_distances,
+    coarse2fine_shape,
+    fine2coarse_distances,
+    fine2coarse_shape,
+    get_refinement_shapewithdtype,
+)
+
+DEPTH_RANGE = (0, 32)
+MAX_SIZE0 = 1024
+
+
+class CoordinateChart():
+    def __init__(
+        self,
+        min_shape: Optional[Iterable[int]] = None,
+        depth: Optional[int] = None,
+        *,
+        shape0: Optional[Iterable[int]] = None,
+        _coarse_size: int = 5,
+        _fine_size: int = 4,
+        _fine_strategy: Literal["jump", "extend"] = "extend",
+        rg2cart: Optional[Callable[[
+            Iterable,
+        ], Iterable]] = None,
+        cart2rg: Optional[Callable[[
+            Iterable,
+        ], Iterable]] = None,
+        regular_axes: Optional[Union[Iterable[int], Tuple]] = None,
+        irregular_axes: Optional[Union[Iterable[int], Tuple]] = None,
+        distances: Optional[Union[Iterable[float], float]] = None,
+        distances0: Optional[Union[Iterable[float], float]] = None,
+    ):
+        """Initialize a refinement chart.
+
+        Parameters
+        ----------
+        min_shape :
+            Minimal extent in pixels along each axis at the final refinement
+            level.
+        depth :
+            Number of refinement iterations.
+        shape0 :
+            Alternative to `min_shape`; specifies the extent in pixels along
+            each axis at the zeroth refinement level.
+        _coarse_size :
+            Number of coarse pixels which to refine to `_fine_size` fine
+            pixels.
+        _fine_size :
+            Number of fine pixels which to refine from `_coarse_size` coarse
+            pixels.
+        _fine_strategy :
+            Whether to space fine pixels solely within the centermost coarse
+            pixel ("jump"), or whether to always space them out s.t. each fine
+            pixels takes up half the Euclidean volume of a coarse pixel
+            ("extend").
+        rg2cart :
+            Function to translate Euclidean points on a regular coordinate
+            system to the Cartesian coordinate system of the modeled points.
+        cart2rg :
+            Inverse of `rg2cart`.
+        regular_axes :
+            Informs the coordinate chart on symmetries within the Cartesian
+            coordinate system of the modeled points. If specified, refinement
+            matrices are broadcast as needed instead of being recomputed.
+        irregular_axes :
+            Complement of `regular_axes`. Specifying either is sufficient.
+        distances :
+            Special case of a coordinate chart in which the regular grid points
+            are merely stretched or compressed. `distances` are used to set the
+            distance between points along every axis at the final refinement
+            level.
+        distances0 :
+            Same as `distances` except that `distances0` refers to the
+            distances along every axis at the zeroth refinement level.
+
+        Note
+        ----
+        The functions `rg2cart` and `cart2rg` are always w.r.t. the grid at
+        zero depth. In other words, it is straightforward to increase the
+        resolution of an existing chart by simply increasing its depth.
+        However, extending a grid spatially is more cumbersome and is best done
+        via `shape0`.
+        """
+        if depth is None:
+            if min_shape is None:
+                raise ValueError("specify `min_shape` to infer `depth`")
+            if shape0 is not None or distances0 is not None:
+                ve = "can not infer `depth` with `shape0` or `distances0` set"
+                raise ValueError(ve)
+            for depth in range(*DEPTH_RANGE):
+                shape0 = fine2coarse_shape(
+                    min_shape,
+                    depth=depth,
+                    ceil_sizes=True,
+                    _coarse_size=_coarse_size,
+                    _fine_size=_fine_size,
+                    _fine_strategy=_fine_strategy
+                )
+                if np.prod(shape0, dtype=int) <= MAX_SIZE0:
+                    break
+            else:
+                ve = f"unable to find suitable `depth`; please specify manually"
+                raise ValueError(ve)
+        if depth < 0:
+            raise ValueError(f"invalid `depth`; got {depth!r}")
+        self._depth = depth
+
+        if shape0 is None and min_shape is not None:
+            shape0 = fine2coarse_shape(
+                min_shape,
+                depth,
+                ceil_sizes=True,
+                _coarse_size=_coarse_size,
+                _fine_size=_fine_size,
+                _fine_strategy=_fine_strategy
+            )
+        elif shape0 is None:
+            raise ValueError("either `shape0` or `min_shape` must be specified")
+        self._shape0 = (shape0, ) if isinstance(shape0, int) else tuple(shape0)
+        self._shape = coarse2fine_shape(
+            shape0,
+            depth,
+            _coarse_size=_coarse_size,
+            _fine_size=_fine_size,
+            _fine_strategy=_fine_strategy
+        )
+
+        if _fine_strategy not in ("jump", "extend"):
+            ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+            raise ValueError(ve)
+
+        self._shape_at = partial(
+            coarse2fine_shape,
+            self.shape0,
+            _coarse_size=_coarse_size,
+            _fine_size=_fine_size,
+            _fine_strategy=_fine_strategy
+        )
+
+        self._coarse_size = int(_coarse_size)
+        self._fine_size = int(_fine_size)
+        self._fine_strategy = _fine_strategy
+
+        # Derived attributes
+        self._ndim = len(self.shape)
+        self._size = np.prod(self.shape, dtype=int)
+
+        if rg2cart is None and cart2rg is None:
+            if distances0 is None and distances is None:
+                distances = jnp.ones((self.ndim, ))
+                distances0 = fine2coarse_distances(
+                    distances,
+                    depth,
+                    _fine_size=_fine_size,
+                    _fine_strategy=_fine_strategy
+                )
+            elif distances0 is not None:
+                distances0 = jnp.broadcast_to(
+                    jnp.atleast_1d(distances0), (self.ndim, )
+                )
+                distances = coarse2fine_distances(
+                    distances0,
+                    depth,
+                    _fine_size=_fine_size,
+                    _fine_strategy=_fine_strategy
+                )
+            else:
+                distances = jnp.broadcast_to(
+                    jnp.atleast_1d(distances), (self.ndim, )
+                )
+                distances0 = fine2coarse_distances(
+                    distances,
+                    depth,
+                    _fine_size=_fine_size,
+                    _fine_strategy=_fine_strategy
+                )
+
+            def _rg2cart(x):
+                x = jnp.asarray(x)
+                return x * distances0.reshape((-1, ) + (1, ) * (x.ndim - 1))
+
+            def _cart2rg(x):
+                x = jnp.asarray(x)
+                return x / distances0.reshape((-1, ) + (1, ) * (x.ndim - 1))
+
+            if regular_axes is None and irregular_axes is None:
+                regular_axes = tuple(range(self.ndim))
+            self._rg2cart = _rg2cart
+            self._cart2rg = _cart2rg
+        elif rg2cart is not None and cart2rg is not None:
+            c0 = jnp.mgrid[tuple(slice(s) for s in self.shape0)]
+            if not all(
+                jnp.allclose(r, c) for r, c in zip(cart2rg(rg2cart(c0)), c0)
+            ):
+                raise ValueError("`cart2rg` is not the inverse of `rg2cart`")
+            self._rg2cart = rg2cart
+            self._cart2rg = cart2rg
+            distances = distances0 = None
+        else:
+            ve = "invalid combination of `cart2rg`, `rg2cart` and `distances`"
+            raise ValueError(ve)
+        self.distances = distances
+        self.distances0 = distances0
+
+        self.distances_at = partial(
+            coarse2fine_distances,
+            self.distances0,
+            _fine_size=_fine_size,
+            _fine_strategy=_fine_strategy
+        )
+
+        if regular_axes is None and irregular_axes is not None:
+            regular_axes = tuple(set(range(self.ndim)) - set(irregular_axes))
+        elif regular_axes is not None and irregular_axes is None:
+            irregular_axes = tuple(set(range(self.ndim)) - set(regular_axes))
+        elif regular_axes is None and irregular_axes is None:
+            regular_axes = ()
+            irregular_axes = tuple(range(self.ndim))
+        else:
+            if set(regular_axes) | set(irregular_axes) != set(range(self.ndim)):
+                ve = "`regular_axes` and `irregular_axes` do not span the full axes"
+                raise ValueError(ve)
+            if set(regular_axes) & set(irregular_axes) != set():
+                ve = "`regular_axes` and `irregular_axes` must be exclusive"
+                raise ValueError(ve)
+        self._regular_axes = tuple(regular_axes)
+        self._irregular_axes = tuple(irregular_axes)
+        if len(self.regular_axes) + len(self.irregular_axes) != self.ndim:
+            ve = (
+                f"length of regular_axes and irregular_axes"
+                f" ({len(self.regular_axes)} + {len(self.irregular_axes)} respectively)"
+                f" incompatible with overall dimension {self.ndim}"
+            )
+            raise ValueError(ve)
+
+        self._descr = {
+            "depth": self.depth,
+            "shape0": self.shape0,
+            "_coarse_size": self.coarse_size,
+            "_fine_size": self.fine_size,
+            "_fine_strategy": self.fine_strategy,
+        }
+        if distances0 is not None:
+            self._descr["distances0"] = tuple(distances0)
+        else:
+            self._descr["rg2cart"] = repr(rg2cart)
+            self._descr["cart2rg"] = repr(cart2rg)
+        self._descr["regular_axes"] = self.regular_axes
+
+    @property
+    def shape(self):
+        """Shape at the final refinement level"""
+        return self._shape
+
+    @property
+    def shape0(self):
+        """Shape at the zeroth refinement level"""
+        return self._shape0
+
+    @property
+    def size(self):
+        return self._size
+
+    @property
+    def ndim(self):
+        return self._ndim
+
+    @property
+    def depth(self):
+        return self._depth
+
+    @property
+    def coarse_size(self):
+        return self._coarse_size
+
+    @property
+    def fine_size(self):
+        return self._fine_size
+
+    @property
+    def fine_strategy(self):
+        return self._fine_strategy
+
+    @property
+    def regular_axes(self):
+        return self._regular_axes
+
+    @property
+    def irregular_axes(self):
+        return self._irregular_axes
+
+    def rg2cart(self, positions):
+        """Translates positions from the regular Euclidean coordinate system to
+        the (in general) irregular Cartesian coordinate system.
+
+        Parameters
+        ----------
+        positions :
+            Positions on a regular Euclidean coordinate system.
+
+        Returns
+        -------
+        positions :
+            Positions on an (in general) irregular Cartesian coordinate system.
+
+        Note
+        ----
+        This method is independent of the refinement level!
+        """
+        return self._rg2cart(positions)
+
+    def cart2rg(self, positions):
+        """Translates positions from the (in general) irregular Cartesian
+        coordinate system to the regular Euclidean coordinate system.
+
+        Parameters
+        ----------
+        positions :
+            Positions on an (in general) irregular Cartesian coordinate system.
+
+        Returns
+        -------
+        positions :
+            Positions on a regular Euclidean coordinate system.
+
+        Note
+        ----
+        This method is independent of the refinement level!
+        """
+        return self._cart2rg(positions)
+
+    def rgoffset(self, lvl: int) -> Tuple[float]:
+        """Calculate the offset on the regular Euclidean grid due to shrinking
+        of the grid with increasing refinement level.
+
+        Parameters
+        ----------
+        lvl :
+            Level of the refinement.
+
+        Returns
+        -------
+        offset :
+            The offset on the regular Euclidean grid along each axis.
+
+        Note
+        ----
+        Indices are assumed to denote the center of the pixels, i.e. the pixel
+        with index `0` is assumed to be at `(0., ) * ndim`.
+        """
+        csz = self.coarse_size  # abbreviations for readability
+        fsz = self.fine_size
+
+        leftmost_center = 0.
+        # Assume the indices denote the center of the pixels, i.e. the pixel
+        # with index 0 is at (0., ) * ndim
+        if self.fine_strategy == "jump":
+            # for i in range(lvl):
+            #     leftmost_center += ((csz - 1) / 2 - 0.5 + 0.5 / fsz) / fsz**i
+            lm0 = (csz - 1) / 2 - 0.5 + 0.5 / fsz
+            # geometric sum: sum(fsz**-i for i in range(lvl))
+            geo = (1. - fsz**-lvl) / (1. - 1. / fsz)
+            leftmost_center = lm0 * geo
+        elif self.fine_strategy == "extend":
+            # for i in range(lvl):
+            #     leftmost_center += ((csz - 1) / 2 - 0.25 * (fsz - 1)) / 2**i
+            lm0 = ((csz - 1) / 2 - 0.25 * (fsz - 1))
+            geo = (1. - 2.**-lvl) * 2.  # sum(2**-i for i in range(lvl))
+            leftmost_center = lm0 * geo
+        else:
+            raise AssertionError()
+        return tuple((leftmost_center, ) * self.ndim)
+
+    def ind2rg(self, indices: Iterable[Union[float, int]],
+               lvl: int) -> Tuple[float]:
+        """Converts pixel indices to a continuous regular Euclidean grid
+        coordinates.
+
+        Parameters
+        ----------
+        indices :
+            Indices of shape `(n_dim, n_indices)` into the NDArray at
+            refinement level `lvl` which to convert to points in our regular
+            Euclidean grid.
+        lvl :
+            Level of the refinement.
+
+        Returns
+        -------
+        rg :
+            Regular Euclidean grid coordinates of shape `(n_dim, n_indices)`.
+        """
+        offset = self.rgoffset(lvl)
+
+        if self.fine_strategy == "jump":
+            dvol = 1 / self.fine_size**lvl
+        elif self.fine_strategy == "extend":
+            dvol = 1 / 2**lvl
+        else:
+            raise AssertionError()
+        return tuple(off + idx * dvol for off, idx in zip(offset, indices))
+
+    def rg2ind(
+        self,
+        positions: Iterable[Union[float, int]],
+        lvl: int,
+        discretize: bool = True
+    ) -> Union[Tuple[float], Tuple[int]]:
+        """Converts continuous regular grid positions to pixel indices.
+
+        Parameters
+        ----------
+        positions :
+            Positions on the regular Euclidean coordinate system of shape
+            `(n_dim, n_indices)` at refinement level `lvl` which to convert to
+            indices in a NDArray at the refinement level `lvl`.
+        lvl :
+            Level of the refinement.
+        discretize :
+            Whether to round indices to the nearest integer.
+
+        Returns
+        -------
+        indices :
+            Indices into the NDArray at refinement level `lvl`.
+        """
+        offset = self.rgoffset(lvl)
+
+        if self.fine_strategy == "jump":
+            dvol = 1 / self.fine_size**lvl
+        elif self.fine_strategy == "extend":
+            dvol = 1 / 2**lvl
+        else:
+            raise AssertionError()
+        indices = tuple(pos / dvol - off for off, pos in zip(offset, positions))
+        if discretize:
+            indices = tuple(jnp.rint(idx).astype(jnp.int32) for idx in indices)
+        return indices
+
+    def ind2cart(self, indices: Iterable[Union[float, int]], lvl: int):
+        """Computes the Cartesian coordinates of a pixel given the indices of
+        it.
+
+        Parameters
+        ----------
+        indices :
+            Indices of shape `(n_dim, n_indices)` into the NDArray at
+            refinement level `lvl` which to convert to locations in our (in
+            general) irregular coordinate system of the modeled points.
+        lvl :
+            Level of the refinement.
+
+        Returns
+        -------
+        positions :
+            Positions in the (in general) irregular coordinate system of the
+            modeled points of shape `(n_dim, n_indices)`.
+        """
+        return self.rg2cart(self.ind2rg(indices, lvl))
+
+    def cart2ind(self, positions, lvl, discretize=True):
+        """Computes the indices of a pixel given the Cartesian coordinates of
+        it.
+
+        Parameters
+        ----------
+        positions :
+            Positions on the Cartesian (in general) irregular coordinate system
+            of the modeled points of shape `(n_dim, n_indices)` at refinement
+            level `lvl` which to convert to indices in a NDArray at the
+            refinement level `lvl`.
+        lvl :
+            Level of the refinement.
+        discretize :
+            Whether to round indices to the nearest integer.
+
+        Returns
+        -------
+        indices :
+            Indices into the NDArray at refinement level `lvl`.
+        """
+        return self.rg2ind(self.cart2rg(positions), lvl, discretize=discretize)
+
+    def shape_at(self, lvl):
+        """Retrieves the shape at a given refinement level `lvl`."""
+        return self._shape_at(lvl)
+
+    def level_of(self, shape: Tuple[int]):
+        """Finds the refinement level at which the number of grid points
+        equate.
+        """
+        if not isinstance(shape, tuple):
+            raise TypeError(f"invalid type of `shape`; got {type(shape)}")
+
+        for lvl in range(self.depth + 1):
+            if shape == self.shape_at(lvl):
+                return lvl
+        raise ValueError(f"invalid shape {shape!r}")
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(**{self._descr})"
+
+    def __eq__(self, other):
+        return repr(self) == repr(other)
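+
+
+# A minimal construction sketch (illustrative values only):
+#
+#   cc = CoordinateChart(min_shape=(128, 128), distances=(0.1, 0.1))
+#   cc.depth, cc.shape0, cc.shape  # inferred depth, coarsest and finest shape
+#   # pixel centers at the finest level in Cartesian coordinates
+#   idx = jnp.mgrid[tuple(slice(s) for s in cc.shape)]
+#   pos = cc.ind2cart(idx, cc.depth)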
+
+
+RefinementMatrices = namedtuple(
+    "RefinementMatrices", ("filter", "propagator_sqrt", "cov_sqrt0")
+)
+
+
+class RefinementField():
+    def __init__(
+        self,
+        *args,
+        kernel: Optional[Callable] = None,
+        dtype=None,
+        skip0: bool = False,
+        **kwargs
+    ):
+        """Initialize an Iterative Charted Refinement (ICR) field.
+
+        There are multiple ways to initialize a charted refinement field. The
+        recommended way is to first instantiate a `CoordinateChart` and pass it
+        as first argument to this method. Alternatively, you may pass any and
+        all arguments of `CoordinateChart` to this method; it will then
+        instantiate the `CoordinateChart` for you and use it as if it had been
+        specified directly.
+
+        Parameters
+        ----------
+        chart : CoordinateChart
+            The `CoordinateChart` with which to refine.
+        kernel :
+            Covariance kernel of the refinement field.
+        dtype :
+            Data-type of the excitations which to add during refining.
+        skip0 :
+            Whether to skip the zeroth refinement level. This is useful e.g.
+            to stack multiple refinement fields on top of each other.
+        **kwargs :
+            Alternatively to `chart` any parameters accepted by
+            `CoordinateChart`.
+        """
+        self._kernel = kernel
+        self._dtype = dtype
+        self._skip0 = skip0
+
+        if len(args) > 0 and isinstance(args[0], CoordinateChart):
+            if kwargs:
+                raise TypeError(f"expected no keyword arguments, got {kwargs}")
+
+            if len(args) == 1:
+                self._chart, = args
+            elif len(args) == 2 and callable(args[1]) and kernel is None:
+                self._chart, self._kernel = args
+            elif len(args) == 3 and callable(
+                args[1]
+            ) and kernel is None and dtype is None:
+                self._chart, self._kernel, self._dtype = args
+            elif len(args) == 4 and callable(
+                args[1]
+            ) and kernel is None and dtype is None and skip0 is False:
+                self._chart, self._kernel, self._dtype, self._skip0 = args
+            else:
+                te = "got unexpected arguments in addition to CoordinateChart"
+                raise TypeError(te)
+        else:
+            self._chart = CoordinateChart(*args, **kwargs)
+
+    @property
+    def kernel(self):
+        """Yields the kernel specified during initialization or throw a
+        `TypeError`.
+        """
+        if self._kernel is None:
+            te = (
+                "either specify a fixed kernel during initialization of the"
+                f" {self.__class__.__name__} class or provide one here"
+            )
+            raise TypeError(te)
+        return self._kernel
+
+    @property
+    def dtype(self):
+        """Yields the data-type of the excitations."""
+        return jnp.float64 if self._dtype is None else self._dtype
+
+    @property
+    def skip0(self):
+        """Whether to skip the zeroth refinement"""
+        return self._skip0
+
+    @property
+    def chart(self):
+        """Associated `CoordinateChart` with which to iterative refine."""
+        return self._chart
+
+    def matrices(
+        self,
+        kernel: Optional[Callable] = None,
+        depth: Optional[int] = None,
+        skip0: Optional[bool] = None,
+        **kwargs
+    ) -> RefinementMatrices:
+        """Computes the refinement matrices namely the optimal linear filter
+        and the square root of the information propagator (a.k.a. the square
+        root of the fine covariance matrix for the excitations) for all
+        refinement levels and all pixel indices in the coordinate chart.
+
+        Parameters
+        ----------
+        kernel :
+            Covariance kernel of the refinement field if not specified during
+            initialization.
+        depth :
+            Maximum refinement depth if different from that of the
+            `CoordinateChart`.
+        skip0 :
+            Whether to skip the zeroth refinement level.
+        """
+        kernel = self.kernel if kernel is None else kernel
+        depth = self.chart.depth if depth is None else depth
+        skip0 = self.skip0 if skip0 is None else skip0
+
+        return _coordinate_refinement_matrices(
+            self.chart, kernel=kernel, depth=depth, skip0=skip0, **kwargs
+        )
+
+    def matrices_at(
+        self,
+        level: int,
+        pixel_index: Optional[Iterable[int]] = None,
+        kernel: Optional[Callable] = None,
+        **kwargs
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        """Computes the refinement matrices namely the optimal linear filter
+        and the square root of the information propagator (a.k.a. the square
+        root of the fine covariance matrix for the excitations) at the
+        specified level and pixel index.
+
+        Parameters
+        ----------
+        level :
+            Refinement level.
+        pixel_index :
+            Index of the NDArray at the refinement level `level` which to
+            refine, i.e. use as center coarse pixel.
+        kernel :
+            Covariance kernel of the refinement field if not specified during
+            initialization.
+        """
+        kernel = self.kernel if kernel is None else kernel
+
+        return _coordinate_pixel_refinement_matrices(
+            self.chart,
+            level=level,
+            pixel_index=pixel_index,
+            kernel=kernel,
+            **kwargs
+        )
+
+    @property
+    def shapewithdtype(self):
+        """Yields the `ShapeWithDtype` of the primals."""
+        return get_refinement_shapewithdtype(
+            shape0=self.chart.shape0,
+            depth=self.chart.depth,
+            dtype=self.dtype,
+            skip0=self.skip0,
+            _coarse_size=self.chart.coarse_size,
+            _fine_size=self.chart.fine_size,
+            _fine_strategy=self.chart.fine_strategy,
+        )
+
+    @staticmethod
+    def apply(
+        xi,
+        chart,
+        kernel: Union[Callable, RefinementMatrices],
+        *,
+        skip0: bool = False,
+        depth: Optional[int] = None,
+        coerce_fine_kernel: bool = True,
+        _refine: Optional[Callable] = None,
+        precision=None,
+    ):
+        """Static method to apply a refinement field given some excitations, a
+        chart and a kernel.
+
+        Parameters
+        ----------
+        xi :
+            Latent parameters which to use for refining.
+        chart :
+            Chart with which to refine.
+        kernel :
+            Covariance kernel with which to build the refinement matrices.
+        skip0 :
+            Whether to skip the zeroth refinement level.
+        depth :
+            Refinement depth if different from that of the coordinate chart.
+        coerce_fine_kernel :
+            Whether to coerce the refinement matrices at scales at which the
+            kernel matrix becomes singular or numerically highly unstable.
+        precision :
+            See JAX's precision.
+        """
+        depth = chart.depth if depth is None else depth
+        if depth != len(xi) - 1:
+            ve = (
+                f"incompatible refinement depths of `xi` ({len(xi) - 1})"
+                f" and `depth` (of chart) {depth}"
+            )
+            raise ValueError(ve)
+
+        if isinstance(kernel, RefinementMatrices):
+            refinement = kernel
+        else:
+            refinement = _coordinate_refinement_matrices(
+                chart,
+                kernel=kernel,
+                depth=depth,
+                skip0=skip0,
+                coerce_fine_kernel=coerce_fine_kernel
+            )
+        refine_w_chart = partial(
+            refine if _refine is None else _refine,
+            _coarse_size=chart.coarse_size,
+            _fine_size=chart.fine_size,
+            _fine_strategy=chart.fine_strategy,
+            precision=precision
+        )
+
+        if not skip0:
+            fine = (refinement.cov_sqrt0 @ xi[0].ravel()).reshape(xi[0].shape)
+        else:
+            if refinement.cov_sqrt0 is not None:
+                raise AssertionError()
+            fine = xi[0]
+        for x, olf, k in zip(
+            xi[1:], refinement.filter, refinement.propagator_sqrt
+        ):
+            fine = refine_w_chart(fine, x, olf, k)
+        return fine
+
+    def __call__(self, xi, kernel=None, *, skip0=None, **kwargs):
+        """See `RefinementField.apply`."""
+        kernel = self.kernel if kernel is None else kernel
+        skip0 = self.skip0 if skip0 is None else skip0
+        return self.apply(xi, self.chart, kernel=kernel, skip0=skip0, **kwargs)
+
+    def __repr__(self):
+        descr = f"{self.__class__.__name__}({self.chart!r}"
+        descr += f", kernel={self._kernel!r}" if self._kernel is not None else ""
+        descr += f", dtype={self._dtype!r}" if self._dtype is not None else ""
+        descr += f", skip0={self.skip0!r}" if self.skip0 is not False else ""
+        descr += ")"
+        return descr
+
+    def __eq__(self, other):
+        return repr(self) == repr(other)
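+
+
+# A minimal usage sketch (illustrative; the kernel is arbitrary and
+# `random_like` stands in for any routine drawing standard-normal excitations
+# matching `shapewithdtype`):
+#
+#   rf = RefinementField(min_shape=(128, ), kernel=lambda r: jnp.exp(-r))
+#   xi = random_like(key, rf.shapewithdtype)
+#   realization = rf(xi)  # correlated field on the finest grid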
+
+
+def _coordinate_pixel_refinement_matrices(
+    chart: CoordinateChart,
+    level: int,
+    pixel_index: Optional[Iterable[int]] = None,
+    kernel: Optional[Callable] = None,
+    *,
+    coerce_fine_kernel: bool = True,
+    _cov_from_loc: Optional[Callable] = None,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    cov_from_loc = _get_cov_from_loc(kernel, _cov_from_loc)
+    csz = int(chart.coarse_size)  # coarse size
+    if csz % 2 != 1:
+        raise ValueError("only odd numbers allowed for `_coarse_size`")
+    fsz = int(chart.fine_size)  # fine size
+    if fsz % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+    ndim = chart.ndim
+    if pixel_index is None:
+        pixel_index = (0, ) * ndim
+    pixel_index = jnp.asarray(pixel_index)
+    if pixel_index.size != ndim:
+        ve = f"`pixel_index` has {pixel_index.size} dimensions but `chart` has {ndim}"
+        raise ValueError(ve)
+
+    csz_half = int((csz - 1) / 2)
+    gc = jnp.arange(-csz_half, csz_half + 1, dtype=float)
+    gc = jnp.ones((ndim, 1)) * gc
+    gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1)
+    if chart.fine_strategy == "jump":
+        gf = jnp.arange(fsz, dtype=float) / fsz - 0.5 + 0.5 / fsz
+    elif chart.fine_strategy == "extend":
+        gf = jnp.arange(fsz, dtype=float) / 2 - 0.25 * (fsz - 1)
+    else:
+        raise ValueError(f"invalid `_fine_strategy`; got {chart.fine_strategy}")
+    gf = jnp.ones((ndim, 1)) * gf
+    gf = jnp.stack(jnp.meshgrid(*gf, indexing="ij"), axis=-1)
+    # On the GPU a single `cov_from_loc` call is about twice as fast as three
+    # separate calls for coarse-coarse, fine-fine and coarse-fine.
+    coord = jnp.concatenate(
+        (gc.reshape(-1, ndim), gf.reshape(-1, ndim)), axis=0
+    )
+    coord = chart.ind2cart((coord + pixel_index.reshape((1, ndim))).T, level)
+    coord = jnp.stack(coord, axis=-1)
+    cov = cov_from_loc(coord, coord)
+    cov_ff = cov[-fsz**ndim:, -fsz**ndim:]
+    cov_fc = cov[-fsz**ndim:, :-fsz**ndim]
+    cov_cc = cov[:-fsz**ndim, :-fsz**ndim]
+    cov_cc_inv = jnp.linalg.inv(cov_cc)
+
+    olf = cov_fc @ cov_cc_inv
+    # Also see Schur-Complement
+    fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T
+    if coerce_fine_kernel:
+        # Implicitly assume a white power spectrum beyond the numerics limit.
+        # Use the diagonal as an estimate for the magnitude of the variance.
+        fine_kernel_fallback = jnp.diag(jnp.abs(jnp.diag(fine_kernel)))
+        # Never produce NaNs (https://github.com/google/jax/issues/1052)
+        # This is expensive but necessary (worse but cheaper:
+        # `jnp.all(jnp.diag(fine_kernel) > 0.)`)
+        is_pos_def = jnp.all(jnp.linalg.eigvalsh(fine_kernel) > 0)
+        fine_kernel = jnp.where(is_pos_def, fine_kernel, fine_kernel_fallback)
+        # NOTE, use the Cholesky decomposition in the following even though
+        # the eigenvalues have already been computed, so as to get consistent
+        # results across platforms
+    fine_kernel_sqrt = jnp.linalg.cholesky(fine_kernel)
+
+    return olf, fine_kernel_sqrt
+
+
+def _coordinate_refinement_matrices(
+    chart: CoordinateChart,
+    kernel: Callable,
+    *,
+    depth: Optional[int] = None,
+    skip0=False,
+    coerce_fine_kernel: bool = True,
+    _cov_from_loc=None
+) -> RefinementMatrices:
+    cov_from_loc = _get_cov_from_loc(kernel, _cov_from_loc)
+    depth = chart.depth if depth is None else depth
+
+    if not skip0:
+        rg0 = jnp.mgrid[tuple(slice(s) for s in chart.shape0)]
+        c0 = jnp.stack(chart.ind2cart(rg0, 0), axis=-1).reshape(-1, chart.ndim)
+        cov_sqrt0 = jnp.linalg.cholesky(cov_from_loc(c0, c0))
+    else:
+        cov_sqrt0 = None
+
+    opt_lin_filter, kernel_sqrt = [], []
+    olf_at = vmap(
+        partial(
+            _coordinate_pixel_refinement_matrices,
+            chart,
+            coerce_fine_kernel=coerce_fine_kernel,
+            _cov_from_loc=cov_from_loc,
+        ),
+        in_axes=(None, 0),
+        out_axes=(0, 0)
+    )
+
+    for lvl in range(depth):
+        shape_lvl = chart.shape_at(lvl)
+        pixel_indices = []
+        for ax in range(chart.ndim):
+            pad = (chart.coarse_size - 1) / 2
+            if int(pad) != pad:
+                raise ValueError("`coarse_size` must be odd")
+            pad = int(pad)
+            if chart.fine_strategy == "jump":
+                stride = 1
+            elif chart.fine_strategy == "extend":
+                stride = chart.fine_size / 2
+                if int(stride) != stride:
+                    raise ValueError("`fine_size` must be even")
+                stride = int(stride)
+            else:
+                raise AssertionError()
+            if ax in chart.irregular_axes:
+                pixel_indices.append(
+                    jnp.arange(pad, shape_lvl[ax] - pad, stride)
+                )
+            else:
+                pixel_indices.append(jnp.array([pad]))
+        pixel_indices = jnp.stack(
+            jnp.meshgrid(*pixel_indices, indexing="ij"), axis=-1
+        )
+        shape_filtered_lvl = pixel_indices.shape[:-1]
+        pixel_indices = pixel_indices.reshape(-1, chart.ndim)
+
+        olf, ks = olf_at(lvl, pixel_indices)
+        shape_bc_lvl = tuple(
+            shape_filtered_lvl[i] if i in chart.irregular_axes else 1
+            for i in range(chart.ndim)
+        )
+        opt_lin_filter.append(olf.reshape(shape_bc_lvl + olf.shape[-2:]))
+        kernel_sqrt.append(ks.reshape(shape_bc_lvl + ks.shape[-2:]))
+
+    return RefinementMatrices(opt_lin_filter, kernel_sqrt, cov_sqrt0)
diff --git a/src/re/refine_util.py b/src/re/refine_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..443baeb00f6566a1751a70415b5be233a19c0c01
--- /dev/null
+++ b/src/re/refine_util.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from math import ceil
+import sys
+from typing import Callable, Iterable, Literal, Optional, Tuple, Union
+from warnings import warn
+
+import jax
+from jax import numpy as jnp
+import numpy as np
+from scipy.spatial import distance_matrix
+
+from .forest_util import zeros_like
+
+
+def get_refinement_shapewithdtype(
+    shape0: Union[int, tuple],
+    depth: int,
+    dtype=None,
+    *,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+    skip0: bool = False,
+):
+    from .forest_util import ShapeWithDtype
+
+    if depth < 0:
+        raise ValueError(f"invalid `depth`; got {depth!r}")
+    csz = int(_coarse_size)  # coarse size
+    fsz = int(_fine_size)  # fine size
+
+    swd = partial(ShapeWithDtype, dtype=dtype)
+
+    shape0 = (shape0, ) if isinstance(shape0, int) else shape0
+    ndim = len(shape0)
+    exc_shp = [swd(shape0)] if not skip0 else [None]
+    if depth > 0:
+        if _fine_strategy == "jump":
+            exc_lvl = tuple(el - (csz - 1) for el in shape0) + (fsz**ndim, )
+        elif _fine_strategy == "extend":
+            exc_lvl = tuple(
+                ceil((el - (csz - 1)) / (fsz // 2)) for el in shape0
+            ) + (fsz**ndim, )
+        else:
+            raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}")
+        exc_shp += [swd(exc_lvl)]
+    for lvl in range(1, depth):
+        if _fine_strategy == "jump":
+            exc_lvl = tuple(
+                fsz * el - (csz - 1) for el in exc_shp[-1].shape[:-1]
+            ) + (fsz**ndim, )
+        elif _fine_strategy == "extend":
+            exc_lvl = tuple(
+                ceil((fsz * el - (csz - 1)) / (fsz // 2))
+                for el in exc_shp[-1].shape[:-1]
+            ) + (fsz**ndim, )
+        else:
+            raise AssertionError()
+        if any(el <= 0 for el in exc_lvl):
+            ve = (
+                f"`shape0` ({shape0}) with `depth` ({depth}) yield an"
+                f" invalid shape ({exc_lvl}) at level {lvl}"
+            )
+            raise ValueError(ve)
+        exc_shp += [swd(exc_lvl)]
+
+    return exc_shp
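+
+
+# For instance (with the defaults `_coarse_size=3`, `_fine_size=2` and the
+# "jump" strategy), `get_refinement_shapewithdtype(12, depth=2)` yields the
+# excitation shapes `[(12, ), (10, 2), (18, 2)]`: the base grid plus one set
+# of fine excitations per refinement level.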
+
+
+def coarse2fine_shape(
+    shape0: Union[int, Iterable[int]],
+    depth: int,
+    *,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    """Translates a coarse shape to its corresponding fine shape."""
+    shape0 = (shape0, ) if isinstance(shape0, int) else shape0
+    csz = int(_coarse_size)  # coarse size
+    fsz = int(_fine_size)  # fine size
+    if _fine_size % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+
+    shape = []
+    for shp in shape0:
+        sz_at = shp
+        for lvl in range(depth):
+            if _fine_strategy == "jump":
+                sz_at = fsz * (sz_at - (csz - 1))
+            elif _fine_strategy == "extend":
+                sz_at = fsz * ceil((sz_at - (csz - 1)) / (fsz // 2))
+            else:
+                ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+                raise ValueError(ve)
+            if sz_at <= 0:
+                ve = (
+                    f"`shape0` ({shape0}) with `depth` ({depth}) yield an"
+                    f" invalid shape ({sz_at}) at level {lvl}"
+                )
+                raise ValueError(ve)
+        shape.append(int(sz_at))
+    return tuple(shape)
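+
+
+# For example, with the defaults (`_coarse_size=3`, `_fine_size=2`, "jump"),
+# each refinement level maps a size `n` to `2 * (n - 2)`, hence
+# `coarse2fine_shape(12, depth=2) == (36, )` (12 -> 20 -> 36).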
+
+
+def fine2coarse_shape(
+    shape: Union[int, Iterable[int]],
+    depth: int,
+    *,
+    _coarse_size: int = 3,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+    ceil_sizes: bool = False,
+):
+    """Translates a fine shape to its corresponding coarse shape."""
+    shape = (shape, ) if isinstance(shape, int) else shape
+    csz = int(_coarse_size)  # coarse size
+    fsz = int(_fine_size)  # fine size
+    if _fine_size % 2 != 0:
+        raise ValueError("only even numbers allowed for `_fine_size`")
+
+    shape0 = []
+    for shp in shape:
+        sz_at = shp
+        for lvl in range(depth, 0, -1):
+            if _fine_strategy == "jump":
+                # solve for n: `fsz * (n - (csz - 1))`
+                sz_at = sz_at / fsz + (csz - 1)
+            elif _fine_strategy == "extend":
+                # solve for n: `fsz * ceil((n - (csz - 1)) / (fsz // 2))`
+                # NOTE, not unique because of `ceil`; use lower limit
+                sz_at_max = (sz_at / fsz) * (fsz // 2) + (csz - 1)
+                sz_at_min = ceil(sz_at_max - (fsz // 2 - 1))
+                for sz_at_cand in range(sz_at_min, ceil(sz_at_max) + 1):
+                    try:
+                        shp_cand = coarse2fine_shape(
+                            (sz_at_cand, ),
+                            depth=depth - lvl + 1,
+                            _coarse_size=csz,
+                            _fine_size=fsz,
+                            _fine_strategy=_fine_strategy
+                        )[0]
+                    except ValueError as e:
+                        if "invalid shape" not in "".join(e.args):
+                            ve = "unexpected behavior of `coarse2fine_shape`"
+                            raise ValueError(ve) from e
+                        shp_cand = -1
+                    if shp_cand >= shp:
+                        sz_at = sz_at_cand
+                        break
+                else:
+                    ve = f"interval search within [{sz_at_min}, {ceil(sz_at_max)}] failed"
+                    raise ValueError(ve)
+            else:
+                ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+                raise ValueError(ve)
+
+            sz_at = ceil(sz_at) if ceil_sizes else sz_at
+            if sz_at != int(sz_at):
+                raise ValueError(f"invalid shape at level {lvl}")
+        shape0.append(int(sz_at))
+    return tuple(shape0)
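+
+
+# The inverse of the example above: with the same defaults,
+# `fine2coarse_shape(36, depth=2) == (12, )` (36 -> 20 -> 12).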
+
+
+def coarse2fine_distances(
+    distances0: Union[float, Iterable[float]],
+    depth: int,
+    *,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    """Translates coarse distances to its corresponding fine distances."""
+    fsz = int(_fine_size)  # fine size
+    if _fine_strategy == "jump":
+        fpx_in_cpx = fsz**depth
+    elif _fine_strategy == "extend":
+        fpx_in_cpx = 2**depth
+    else:
+        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+        raise ValueError(ve)
+
+    return jnp.atleast_1d(distances0) / fpx_in_cpx
+
+
+def fine2coarse_distances(
+    distances: Union[float, Iterable[float]],
+    depth: int,
+    *,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    """Translates fine distances to its corresponding coarse distances."""
+    fsz = int(_fine_size)  # fine size
+    if _fine_strategy == "jump":
+        fpx_in_cpx = fsz**depth
+    elif _fine_strategy == "extend":
+        fpx_in_cpx = 2**depth
+    else:
+        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+        raise ValueError(ve)
+
+    return jnp.atleast_1d(distances) * fpx_in_cpx
+
+
+def _clipping_posdef_logdet(mat, msg_prefix=""):
+    sign, logdet = jnp.linalg.slogdet(mat)
+    if sign <= 0:
+        ve = "not positive definite; clipping eigenvalues"
+        warn(msg_prefix + ve)
+        eps = jnp.finfo(mat.dtype.type).eps
+        evs = jnp.linalg.eigvalsh(mat)
+        logdet = jnp.sum(jnp.log(jnp.clip(evs, a_min=eps * evs.max())))
+    return logdet
+
+
+def gauss_kl(cov_desired, cov_approx, *, m_desired=None, m_approx=None):
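+    """Compute the Kullback-Leibler (KL) divergence between two Gaussians,
+    i.e. `KL(N(m_desired, cov_desired) || N(m_approx, cov_approx))`
+
+    .. math::
+        2 \\mathrm{KL} = \\ln\\frac{|C_a|}{|C_d|}
+        + \\mathrm{tr}\\left(C_a^{-1} C_d\\right) - k
+        + (m_a - m_d)^T C_a^{-1} (m_a - m_d)
+
+    with :math:`k` denoting the number of dimensions.
+    """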
+    cov_t_dl = _clipping_posdef_logdet(cov_desired, msg_prefix="`cov_desired` ")
+    cov_a_dl = _clipping_posdef_logdet(cov_approx, msg_prefix="`cov_approx` ")
+    cov_a_inv = jnp.linalg.inv(cov_approx)
+
+    kl = -cov_desired.shape[0]  # number of dimensions
+    kl += cov_a_dl - cov_t_dl + jnp.trace(cov_a_inv @ cov_desired)
+    if m_approx is not None and m_desired is not None:
+        m_diff = m_approx - m_desired
+        kl += m_diff @ cov_a_inv @ m_diff
+    elif not (m_approx is None and m_desired is None):
+        ve = "either both or neither of `m_approx` and `m_desired` must be `None`"
+        raise ValueError(ve)
+    return 0.5 * kl
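+
+
+# A minimal self-check (illustrative only):
+#   cov = jnp.array([[2., 0.5], [0.5, 1.]])
+#   gauss_kl(cov, cov)         # ~0. up to floating-point error
+#   gauss_kl(cov, jnp.eye(2))  # > 0.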
+
+
+def refinement_covariance(chart, kernel, jit=True):
+    """Computes the implied covariance as modeled by the refinement scheme."""
+    from .refine_chart import RefinementField
+
+    cf = RefinementField(chart, kernel=kernel)
+    try:
+        cf_T = jax.linear_transpose(cf, cf.shapewithdtype)
+        cov_implicit = lambda x: cf(*cf_T(x))
+        cov_implicit = jax.jit(cov_implicit) if jit else cov_implicit
+        _ = cov_implicit(jnp.zeros(chart.shape))  # Test transpose
+    except (NotImplementedError, AssertionError):
+        # Workaround JAX not yet implementing the transpose of the scanned
+        # refinement
+        _, cf_T = jax.vjp(cf, zeros_like(cf.shapewithdtype))
+        cov_implicit = lambda x: cf(*cf_T(x))
+        cov_implicit = jax.jit(cov_implicit) if jit else cov_implicit
+
+    probe = jnp.zeros(chart.shape)
+    indices = np.indices(chart.shape).reshape(chart.ndim, -1)
+    cov_empirical = jax.lax.map(
+        lambda idx: cov_implicit(probe.at[tuple(idx)].set(1.)).ravel(),
+        indices.T
+    ).T  # vmap over `indices` w/ `in_axes=1, out_axes=-1`
+
+    return cov_empirical
+
+
+def true_covariance(chart, kernel, depth=None):
+    """Computes the true covariance at the final grid."""
+    depth = chart.depth if depth is None else depth
+
+    c0_slc = tuple(slice(sz) for sz in chart.shape_at(depth))
+    pos = jnp.stack(chart.ind2cart(jnp.mgrid[c0_slc], depth),
+                    axis=-1).reshape(-1, chart.ndim)
+    dist_mat = distance_matrix(pos, pos)
+    return kernel(dist_mat)
+
+
+def refinement_approximation_error(
+    chart,
+    kernel: Callable,
+    cutout: Optional[Union[slice, int, Tuple[slice], Tuple[int]]] = None,
+):
+    """Computes the Kullback-Leibler (KL) divergence of the true covariance versus the
+    approximative one for a given kernel and shape of the fine grid.
+
+    If the desired shape cannot be matched, the next larger one is used and
+    the field is subsequently cropped to the desired shape.
+    """
+
+    suggested_min_shape = 2 * 4**chart.depth
+    if any(s <= suggested_min_shape for s in chart.shape):
+        msg = (
+            f"shape {chart.shape} potentially too small"
+            f" (desired {(suggested_min_shape, ) * chart.ndim} (=`2*4^depth`))"
+        )
+        warn(msg)
+
+    cov_empirical = refinement_covariance(chart, kernel)
+    cov_truth = true_covariance(chart, kernel)
+
+    if cutout is None and all(s > suggested_min_shape for s in chart.shape):
+        cutout = (suggested_min_shape, ) * chart.ndim
+        print(
+            f"cropping field (w/ shape {chart.shape}) to {cutout}",
+            file=sys.stderr
+        )
+    if cutout is not None:
+        if isinstance(cutout, slice):
+            cutout = (cutout, ) * chart.ndim
+        elif isinstance(cutout, int):
+            cutout = (slice(cutout), ) * chart.ndim
+        elif isinstance(cutout, tuple):
+            if all(isinstance(el, slice) for el in cutout):
+                pass
+            elif all(isinstance(el, int) for el in cutout):
+                cutout = tuple(slice(el) for el in cutout)
+            else:
+                raise TypeError("elements of `cutout` of invalid type")
+        else:
+            raise TypeError("`cutout` of invalid type")
+
+        cov_empirical = cov_empirical.reshape(chart.shape * 2)[cutout * 2]
+        cov_truth = cov_truth.reshape(chart.shape * 2)[cutout * 2]
+        sz = np.prod(cov_empirical.shape[:chart.ndim])
+        if np.prod(cov_truth.shape[:chart.ndim]) != sz or not np.issubdtype(
+            sz.dtype, np.integer
+        ):
+            raise AssertionError()
+        cov_empirical = cov_empirical.reshape(sz, sz)
+        cov_truth = cov_truth.reshape(sz, sz)
+
+    aux = {
+        "cov_empirical": cov_empirical,
+        "cov_truth": cov_truth,
+    }
+    return gauss_kl(cov_truth, cov_empirical), aux
diff --git a/src/re/stats_distributions.py b/src/re/stats_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6980b8196909ec2ca7f342698f3824afdfb2b8
--- /dev/null
+++ b/src/re/stats_distributions.py
@@ -0,0 +1,254 @@
+from typing import Callable, Optional
+
+from jax import numpy as jnp
+
+
+def laplace_prior(alpha) -> Callable:
+    """
+    Takes random normal samples and outputs samples distributed according to
+
+    .. math::
+        P(x|a) = \\frac{1}{2a} \\exp\\left(-\\frac{|x|}{a}\\right)
+
+    """
+    from jax.scipy.stats import norm
+
+    def standard_to_laplace(xi):
+        res = (xi < 0) * (norm.logcdf(xi) + jnp.log(2))
+        res -= (xi > 0) * (norm.logcdf(-xi) + jnp.log(2))
+        return res * alpha
+
+    return standard_to_laplace
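+
+
+# For instance (an illustrative sketch):
+#   from jax import random
+#   xi = random.normal(random.PRNGKey(0), (10_000, ))
+#   x = laplace_prior(2.)(xi)  # approximately Laplace-distributed, scale 2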
+
+
+def normal_prior(mean, std) -> Callable:
+    """Match standard normally distributed random variables to non-standard
+    variables.
+    """
+    def standard_to_normal(xi):
+        return mean + std * xi
+
+    return standard_to_normal
+
+
+def lognormal_moments(mean, std):
+    """Compute the cumulants a log-normal process would need to comply with the
+    provided mean and standard-deviation `std`
+    """
+    if jnp.any(mean <= 0.):
+        raise ValueError(f"`mean` must be greater zero; got {mean!r}")
+    if jnp.any(std <= 0.):
+        raise ValueError(f"`std` must be greater zero; got {std!r}")
+
+    logstd = jnp.sqrt(jnp.log1p((std / mean)**2))
+    logmean = jnp.log(mean) - 0.5 * logstd**2
+    return logmean, logstd
+
+
+def lognormal_prior(mean, std) -> Callable:
+    """Moment-match standard normally distributed random variables to log-space
+
+    Takes standard normal samples `xi` and maps them to
+
+    .. math::
+        y = \\exp(\\mu + \\sigma \\xi)
+
+    such that the mean and standard deviation of the resulting log-normal
+    distribution match the specified values.
+    """
+    standard_to_normal = normal_prior(*lognormal_moments(mean, std))
+
+    def standard_to_lognormal(xi):
+        return jnp.exp(standard_to_normal(xi))
+
+    return standard_to_lognormal
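+
+
+# The moment matching can be sanity-checked empirically (illustrative):
+#   from jax import random
+#   y = lognormal_prior(3., 1.)(random.normal(random.PRNGKey(0), (100_000, )))
+#   y.mean(), y.std()  # approx. (3., 1.) by construction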
+
+
+def lognormal_invprior(mean, std) -> Callable:
+    """Get the inverse transform to `lognormal_prior`."""
+    ln_m, ln_std = lognormal_moments(mean, std)
+
+    def lognormal_to_standard(y):
+        return (jnp.log(y) - ln_m) / ln_std
+
+    return lognormal_to_standard
+
+
+def uniform_prior(a_min=0., a_max=1.) -> Callable:
+    """Transform a standard normal into a uniform distribution.
+
+    Parameters
+    ----------
+    a_min : float
+        Minimum value.
+    a_max : float
+        Maximum value.
+    """
+    from jax.scipy.stats import norm
+
+    if a_min == 0. and a_max == 1.:
+        return norm.cdf
+
+    scale = a_max - a_min
+
+    def standard_to_uniform(xi):
+        return a_min + scale * norm.cdf(xi)
+
+    return standard_to_uniform
+
+
+def interpolator(
+    func: Callable,
+    xmin: float,
+    xmax: float,
+    *,
+    step: Optional[float] = None,
+    num: Optional[int] = None,
+    table_func: Optional[Callable] = None,
+    inv_table_func: Optional[Callable] = None,
+    return_inverse: Optional[bool] = False
+):  # Adapted from NIFTy
+    """
+    Evaluate a function point-wise by interpolation. Can be supplied with a
+    `table_func` to increase the interpolation accuracy. Best results are
+    achieved when `lambda x: table_func(func(x))` is roughly linear.
+
+    Parameters
+    ----------
+    func : function
+        Function to interpolate.
+    xmin : float
+        The smallest value for which `func` will be evaluated.
+    xmax : float
+        The largest value for which `func` will be evaluated.
+    step : float
+        Distance between sampling points for linear interpolation. Either of
+        `step` or `num` must be specified.
+    num : int
+        The number of interpolation points. Either of `step` or `num` must be
+        specified.
+    table_func : function
+        Non-linear function applied to the tabulated function in order to
+        transform the table to a more linear space.
+    inv_table_func : function
+        Inverse of `table_func`.
+    return_inverse : bool
+        Whether to also return the interpolation of the inverse of `func`. Only
+        sensible if `func` is invertible.
+    """
+    # from scipy.interpolate import CubicSpline
+
+    if step is not None and num is not None:
+        ve = "either but not both of `step` and `num` must be specified"
+        raise ValueError(ve)
+    if step is not None:
+        xs = jnp.arange(xmin, xmax + step, step)
+    elif num is not None:
+        xs = jnp.linspace(xmin, xmax, num)
+    else:
+        ve = "either of `step` or `num` must be specified"
+        raise ValueError(ve)
+
+    ys = func(xs)
+    if table_func is not None:
+        if inv_table_func is None:
+            raise ValueError("no `inv_table_func` specified")
+        ys = table_func(ys)
+
+    # interpolator = CubicSpline(xs, ys)
+    # deriv = interpolator.derivative()
+
+    def interp(x):
+        # res = interpolator(x)
+        res = jnp.interp(x, xs, ys)
+        if inv_table_func is not None:
+            res = inv_table_func(res)
+        return res
+
+    if return_inverse:
+
+        def inverse_interp(y):
+            if table_func is not None:
+                y = table_func(y)
+            return jnp.interp(y, ys, xs)
+
+        return interp, inverse_interp
+
+    return interp
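+
+
+# Usage sketch: tabulate `jnp.exp` once and evaluate it via interpolation:
+#   fast_exp = interpolator(jnp.exp, -5., 5., num=1_000)
+#   fast_exp(jnp.array([0., 1.]))  # approx. (1., 2.718...)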
+
+
+def invgamma_prior(a, scale, loc=0., step=1e-2) -> Callable:
+    """Transform a standard normal into an inverse gamma distribution.
+
+    The pdf of the inverse gamma distribution is defined as follows using
+    :math:`q` to denote the scale:
+
+    .. math::
+
+        P(x|q, a) = \\frac{q^a}{\\Gamma(a)}x^{-a -1}
+        \\exp \\left(-\\frac{q}{x}\\right)
+
+    That means that for large x the pdf falls off like :math:`x^{(-a -1)}`.
+    The mean of the pdf is at :math:`q / (a - 1)` if :math:`a > 1`.
+    The mode is :math:`q / (a + 1)`.
+
+    This transformation is implemented as a linear interpolation which maps a
+    Gaussian onto an inverse gamma distribution.
+
+    Parameters
+    ----------
+    a : float
+        The shape-parameter of the inverse-gamma distribution.
+    scale : float
+        The scale-parameter of the inverse-gamma distribution.
+    loc : float
+        An optional shift of the whole distribution.
+    step : float
+        Distance between sampling points for linear interpolation.
+    """
+    from scipy.stats import invgamma, norm
+
+    if not jnp.isscalar(a) or not jnp.isscalar(loc):
+        te = (
+            "Shape `a` and location `loc` must be of scalar type"
+            f"; got {type(a)} and {type(loc)} respectively"
+        )
+        raise TypeError(te)
+    if loc == 0.:
+        # Pull out `scale` to interpolate less
+        s2i = lambda x: invgamma.ppf(norm._cdf(x), a=a)
+    elif jnp.isscalar(scale):
+        s2i = lambda x: invgamma.ppf(norm._cdf(x), a=a, loc=loc, scale=scale)
+    else:
+        raise TypeError("`scale` may only be array-like for `loc == 0.`")
+
+    xmin, xmax = -8.2, 8.2  # (1. - norm.cdf(8.2)) * 2 < 1e-15
+    standard_to_invgamma_interp = interpolator(
+        s2i, xmin, xmax, step=step, table_func=jnp.log, inv_table_func=jnp.exp
+    )
+
+    def standard_to_invgamma(x):
+        # Allow for array-like `scale` without separate interpolations and only
+        # interpolate for shape `a` and `loc`
+        if loc == 0.:
+            return standard_to_invgamma_interp(x) * scale
+        return standard_to_invgamma_interp(x)
+
+    return standard_to_invgamma
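+
+
+# Illustrative sketch: for `a=3.` and `scale=1.` the transformed samples
+# should have a mean of roughly `scale / (a - 1) = 0.5` (see above):
+#   from jax import random
+#   xi = random.normal(random.PRNGKey(0), (100_000, ))
+#   invgamma_prior(a=3., scale=1.)(xi).mean()  # approx. 0.5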
+
+
+def invgamma_invprior(a, scale, loc=0., step=1e-2) -> Callable:
+    """Get the inverse transformation to `invgamma_prior`."""
+    from scipy.stats import invgamma, norm
+
+    xmin, xmax = -8.2, 8.2  # (1. - norm.cdf(8.2)) * 2 < 1e-15
+    _, invgamma_to_standard = interpolator(
+        lambda x: invgamma.ppf(norm._cdf(x), a=a, loc=loc, scale=scale),
+        xmin,
+        xmax,
+        step=step,
+        table_func=jnp.log,
+        inv_table_func=jnp.exp,
+        return_inverse=True
+    )
+    return invgamma_to_standard
diff --git a/src/re/structured_kernel_interpolation.py b/src/re/structured_kernel_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2e388282aebc2cb3f3e03f2c4fa9d322983a99
--- /dev/null
+++ b/src/re/structured_kernel_interpolation.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from typing import Callable, Optional, Tuple, Union
+
+import jax
+from jax import numpy as jnp
+import numpy as np
+
+from .correlated_field import get_fourier_mode_distributor, hartley
+
+NDArray = Union[jnp.ndarray, np.ndarray]
+
+
+def interp_mat(grid_shape, grid_bounds, sampling_points, *, distances=None):
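+    """Build a sparse matrix of multilinear interpolation weights which maps
+    values on a regular grid of inducing points to arbitrary
+    `sampling_points` of shape `(ndim, n_points)`.
+
+    Exactly one of `grid_bounds` or `distances` must be specified.
+    """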
+    from scipy.sparse import coo_matrix  # TODO: use only JAX w/o SciPy or NumPy
+    from jax.experimental.sparse import BCOO
+
+    if sampling_points.ndim != 2:
+        ve = f"invalid dimension of sampling_points {sampling_points.ndim!r}"
+        raise ValueError(ve)
+    ndim, n_points = sampling_points.shape
+    if grid_bounds is not None and len(grid_bounds) != ndim:
+        ve = (
+            f"grid_bounds of length {len(grid_bounds)} incompatible with"
+            " sampling_points of shape {sampling_points.shape!r}"
+        )
+        raise ValueError(ve)
+    elif grid_bounds is not None:
+        offset = np.array(list(zip(*grid_bounds))[0])
+    else:
+        offset = np.zeros(ndim)
+    if distances is not None and np.size(distances) != ndim:
+        ve = (
+            f"distances of size {np.size(distances)} incompatible with"
+            " sampling_points of shape {sampling_points.shape!r}"
+        )
+        raise ValueError(ve)
+    distances = np.asarray(distances) if distances is not None else None
+    if (distances is not None and grid_bounds
+        is not None) or (distances is None and grid_bounds is None):
+        raise ValueError("exactly one of `distances` or `grid_shape` expected")
+    elif grid_bounds is not None:
+        distances = np.array(
+            [(b[1] - b[0]) / sz for b, sz in zip(grid_bounds, grid_shape)]
+        )
+    if distances is None:
+        raise AssertionError()
+
+    mg = np.mgrid[(slice(0, 2), ) * ndim].reshape(ndim, -1)
+    pos = (sampling_points - offset.reshape(-1, 1)) / distances.reshape(-1, 1)
+    excess, pos = np.modf(pos)
+    pos = pos.astype(np.int64)
+    # max_index = np.array(grid_shape).reshape(-1, 1)
+    weights = np.zeros((2**ndim, n_points))
+    ii = np.zeros((2**ndim, n_points), dtype=np.int64)
+    jj = np.zeros((2**ndim, n_points), dtype=np.int64)
+    for i in range(2**ndim):
+        weights[i, :] = np.prod(
+            np.abs(1 - mg[:, i].reshape(-1, 1) - excess), axis=0
+        )
+        fromi = (pos + mg[:, i].reshape(-1, 1))  # % max_index
+        ii[i, :] = np.arange(n_points)
+        jj[i, :] = np.ravel_multi_index(fromi, grid_shape)
+
+    mat = coo_matrix(
+        (weights.ravel(), (ii.ravel(), jj.ravel())),
+        shape=(n_points, np.prod(grid_shape))
+    )
+    # BCOO(
+    #     (weights.ravel(), jnp.stack((ii.ravel(), jj.ravel()), axis=1)),
+    #     shape=(n_points, np.prod(grid_shape))
+    # )
+    return BCOO.from_scipy_sparse(mat)
+
+
+class HarmonicSKI():
+    def __init__(
+        self,
+        grid_shape: Tuple[int],
+        grid_bounds: Tuple[Tuple[float, float]],
+        sampling_points: NDArray,
+        harmonic_kernel: Optional[Callable] = None,
+        padding: float = 0.5,
+        subslice=None,
+        jitter: Union[bool, float, None] = True
+    ):
+        """Instantiate a KISS-GP model of the covariance using a harmonic
+        representation of the kernel.
+
+        Parameters
+        ----------
+        grid_shape :
+            Number of pixels along each axis of the inducing points within
+            `grid_bounds`.
+        grid_bounds :
+            Tuple of boundaries of length of the number of dimensions. The
+            boundaries should denote the leftmost and rightmost edge of the
+            modeling space.
+        sampling_points :
+            Locations of the modeled points within the grid.
+        harmonic_kernel :
+            Harmonically transformed kernel.
+        padding :
+            Padding factor which to apply along each axis.
+        subslice :
+            Slice of the inducing points which to use to model
+            `sampling_points`. By default, the subslice is determined by the
+            padding.
+        jitter :
+            Strength of the diagonal jitter which to add to the covariance.
+        """
+        if jitter is True:
+            if sampling_points.dtype.type == np.float64:
+                self.jitter = 1e-8
+            elif sampling_points.dtype.type == np.float32:
+                self.jitter = 1e-6
+            else:
+                raise NotImplementedError()
+        elif jitter is False:
+            self.jitter = None
+        else:
+            self.jitter = jitter
+
+        self.grid_unpadded_shape = np.asarray(grid_shape)
+        self.grid_unpadded_bounds = np.asarray(grid_bounds)
+        self.grid_unpadded_distances = jnp.diff(
+            self.grid_unpadded_bounds, axis=1
+        ).ravel() / self.grid_unpadded_shape
+        self.grid_unpadded_total_volume = jnp.prod(
+            self.grid_unpadded_shape * self.grid_unpadded_distances
+        )
+        self.w = interp_mat(grid_shape, grid_bounds, sampling_points)
+
+        if padding is not None and padding != 0.:
+            pad = 1. + padding
+            grid_shape = np.asarray(grid_shape)
+            grid_shape_wpad = np.ceil(grid_shape * pad).astype(int)
+            scl = grid_shape_wpad / grid_shape
+            scl_end = jnp.diff(jnp.asarray(grid_bounds), axis=1).ravel() * scl
+            grid_bounds_wpad = jnp.asarray(grid_bounds)
+            grid_bounds_wpad = grid_bounds_wpad.at[:, 1].set(
+                grid_bounds_wpad[:, 0].ravel() + scl_end
+            )
+            if subslice is None:
+                subslice = tuple(map(int, grid_shape))
+            grid_shape = grid_shape_wpad
+            grid_bounds = grid_bounds_wpad
+        self.grid_shape = np.asarray(grid_shape)
+        self.grid_bounds = np.asarray(grid_bounds)
+        self.grid_distances = jnp.diff(self.grid_bounds,
+                                       axis=1).ravel() / self.grid_shape
+        self.grid_total_volume = jnp.prod(self.grid_shape * self.grid_distances)
+
+        self.power_distributor, self.unique_mode_lengths, _ = get_fourier_mode_distributor(
+            self.grid_shape, self.grid_distances
+        )
+
+        if subslice is not None:
+            if isinstance(subslice, slice):
+                subslice = (subslice, ) * len(self.grid_shape)
+            elif isinstance(subslice, int):
+                subslice = (slice(subslice), ) * len(self.grid_shape)
+            elif isinstance(subslice, tuple):
+                if all(isinstance(el, slice) for el in subslice):
+                    pass
+                elif all(isinstance(el, int) for el in subslice):
+                    subslice = tuple(slice(el) for el in subslice)
+                else:
+                    raise TypeError("elements of `subslice` of invalid type")
+            else:
+                raise TypeError("`subslice` of invalid type")
+        self.grid_subslice = subslice
+
+        self._harmonic_kernel = harmonic_kernel
+
+    @property
+    def harmonic_kernel(self) -> Callable:
+        """Yields the harmonic kernel specified during initialization or throw
+        a `TypeError`.
+        """
+        if self._harmonic_kernel is None:
+            te = (
+                "either specify a fixed harmonic kernel during initialization"
+                f" of the {self.__class__.__name__} class or provide one here"
+            )
+            raise TypeError(te)
+        return self._harmonic_kernel
+
+    def power(self, harmonic_kernel=None) -> NDArray:
+        if harmonic_kernel is None:
+            harmonic_kernel = self.harmonic_kernel
+        power = harmonic_kernel(self.unique_mode_lengths)
+        power *= self.grid_total_volume / self.grid_unpadded_total_volume
+        return power
+
+    def amplitude(self, harmonic_kernel=None):
+        power = self.power(harmonic_kernel)
+        # Assume that the kernel scales linearly with the total volume
+        return jnp.sqrt(power)
+
+    def harmonic_transform(self, x) -> NDArray:
+        return 1. / self.grid_total_volume * hartley(x)
+
+    def correlated_field(self, x, harmonic_kernel=None) -> NDArray:
+        amp = self.amplitude(harmonic_kernel)
+        f = self.harmonic_transform(amp[self.power_distributor] * x)
+        if self.grid_subslice is None:
+            return f
+        return f[self.grid_subslice]
+
+    def sandwich(self, x, harmonic_kernel=None) -> NDArray:
+        if self.grid_subslice is None:
+            x_wpad = x
+        else:
+            x_wpad = jnp.zeros(tuple(self.grid_shape))
+            x_wpad = x_wpad.at[self.grid_subslice].set(x)
+
+        swd = jax.ShapeDtypeStruct(tuple(self.grid_shape), x.dtype)
+        ht = self.harmonic_transform
+        ht_T = jax.linear_transpose(self.harmonic_transform, swd)
+
+        power = self.power(harmonic_kernel=harmonic_kernel)
+        s = ht(power[self.power_distributor] * ht_T(x_wpad)[0])
+        if self.grid_subslice is None:
+            return s
+        return s[self.grid_subslice]
+
+    def __call__(self, x, harmonic_kernel=None) -> NDArray:
+        """Applies the Covariance matrix."""
+        x_shp = x.shape
+        jitter = 0. if self.jitter is None else self.jitter * x
+
+        x = (self.w.T @ x.ravel()).reshape(tuple(self.grid_unpadded_shape))
+        x = self.sandwich(x, harmonic_kernel=harmonic_kernel)
+        x = (self.w @ x.ravel()).reshape(x_shp)
+        return x + jitter
+
+    def evaluate(self, harmonic_kernel=None):
+        """Instantiate the full covariance matrix."""
+        probe = jnp.zeros(self.w.shape[0])
+        indices = jnp.arange(self.w.shape[0]).reshape(1, -1)
+
+        return jax.lax.map(
+            lambda idx: self(
+                probe.at[tuple(idx)].set(1.), harmonic_kernel=harmonic_kernel
+            ).ravel(), indices.T
+        ).T  # vmap over `indices` w/ `in_axes=1, out_axes=-1`
+
+    def evaluate_(self, kernel) -> NDArray:
+        from scipy.spatial import distance_matrix
+
+        if self.jitter is None:
+            jitter = 0.
+        else:
+            jitter = self.jitter * jnp.eye(self.w.shape[0])
+
+        p = [
+            np.linspace(*b, num=sz, endpoint=True) for b, sz in
+            zip(self.grid_unpadded_bounds, self.grid_unpadded_shape)
+        ]
+        p = np.stack(np.meshgrid(*p, indexing="ij"),
+                     axis=-1).reshape(-1, len(self.grid_unpadded_shape))
+        kernel_inducing = kernel(distance_matrix(p, p))
+
+        return self.w @ kernel_inducing @ self.w.T + jitter
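+
+
+# A hypothetical end-to-end sketch (kernel and points are made up):
+#   pts = np.random.uniform(0.1, 0.9, size=(1, 50))  # shape (ndim, n_points)
+#   ski = HarmonicSKI(
+#       (64, ), ((0., 1.), ), pts, harmonic_kernel=lambda k: 1. / (1. + k**2)
+#   )
+#   cov = ski.evaluate()  # dense (50, 50) KISS-GP covariance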
diff --git a/src/re/sugar.py b/src/re/sugar.py
new file mode 100644
index 0000000000000000000000000000000000000000..33dfd0468366804c1b2f3dbc2cdba3110ff98195
--- /dev/null
+++ b/src/re/sugar.py
@@ -0,0 +1,137 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, Hashable, Mapping, TypeVar, Union
+
+from jax import numpy as jnp
+from jax import random
+from jax.tree_util import tree_map, tree_reduce, tree_structure, tree_unflatten
+
+from .field import Field
+
+O = TypeVar('O')
+I = TypeVar('I')
+
+
+def isiterable(candidate):
+    try:
+        iter(candidate)
+        return True
+    except (TypeError, AttributeError):
+        return False
+
+
+def is1d(ls: Any) -> bool:
+    """Indicates whether the input is one dimensional.
+
+    An object is considered one dimensional if it is an iterable of
+    non-iterable items.
+    """
+    if hasattr(ls, "ndim"):
+        return ls.ndim == 1
+    if not isiterable(ls):
+        return False
+    return all(not isiterable(e) for e in ls)
+
+
+def doc_from(original):
+    def wrapper(target):
+        target.__doc__ = original.__doc__
+        return target
+
+    return wrapper
+
+
+def ducktape(call: Callable[[I], O],
+             key: Hashable) -> Callable[[Mapping[Hashable, I]], O]:
+    def named_call(p):
+        return call(p[key])
+
+    return named_call
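+
+
+# For instance (illustrative): `ducktape(lambda x: x**2, "xi")({"xi": 3.})`
+# evaluates to `9.`; `ducktape_left` below does the inverse wrapping.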
+
+
+def ducktape_left(call: Callable[[I], O],
+                  key: Hashable) -> Callable[[I], Dict[Hashable, O]]:
+    def named_call(p):
+        return {key: call(p)}
+
+    return named_call
+
+
+def sum_of_squares(tree) -> Union[jnp.ndarray, jnp.inexact]:
+    return tree_reduce(jnp.add, tree_map(lambda x: jnp.sum(x**2), tree), 0.)
+
+
+def mean(forest):
+    from functools import reduce
+
+    norm = 1. / len(forest)
+    if isinstance(forest[0], Field):
+        m = norm * reduce(Field.__add__, forest)
+        return m
+    else:
+        m = norm * reduce(Field.__add__, (Field(t) for t in forest))
+        return m.val
+
+
+def mean_and_std(forest, correct_bias=True):
+    if isinstance(forest[0], Field):
+        m = mean(forest)
+        mean_of_sq = mean(tuple(t**2 for t in forest))
+    else:
+        m = Field(mean(forest))
+        mean_of_sq = Field(mean(tuple(Field(t)**2 for t in forest)))
+
+    n = len(forest)
+    scl = jnp.sqrt(n / (n - 1)) if correct_bias else 1.
+    std = scl * tree_map(jnp.sqrt, mean_of_sq - m**2)
+    if isinstance(forest[0], Field):
+        return m, std
+    else:
+        return m.val, std.val
+
+
+def random_like(key: Iterable, primals, rng: Callable = random.normal):
+    import numpy as np
+
+    struct = tree_structure(primals)
+    # Cast the subkeys to the structure of `primals`
+    subkeys = tree_unflatten(struct, random.split(key, struct.num_leaves))
+
+    def draw(key, x):
+        shp = x.shape if hasattr(x, "shape") else jnp.shape(x)
+        dtp = x.dtype if hasattr(x, "dtype") else np.common_type(x)
+        return rng(key=key, shape=shp, dtype=dtp)
+
+    return tree_map(draw, subkeys, primals)
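+
+
+# Illustrative sketch: draw standard-normal samples with the same tree
+# structure, shapes and dtypes as the primals:
+#   primals = {"a": jnp.zeros((2, )), "b": jnp.zeros(())}
+#   smpl = random_like(random.PRNGKey(42), primals)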
+
+
+def interpolate(xmin=-7., xmax=7., N=14000) -> Callable:
+    """Replaces a local nonlinearity such as jnp.exp with a linear interpolation
+
+    Interpolating functions speeds up code and increases numerical stability in
+    some cases, but at a cost of precision and range.
+
+    Parameters
+    ----------
+    xmin : float
+        Minimal interpolation value. Default: -7.
+    xmax : float
+        Maximal interpolation value. Default: 7.
+    N : int
+        Number of points used for the interpolation. Default: 14000
+    """
+    def decorator(f):
+        from functools import wraps
+
+        x = jnp.linspace(xmin, xmax, N)
+        y = f(x)
+
+        @wraps(f)
+        def wrapper(t):
+            return jnp.interp(t, x, y)
+
+        return wrapper
+
+    return decorator
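+
+
+# Usage sketch for the decorator (illustrative):
+#   @interpolate(xmin=-7., xmax=7., N=14000)
+#   def fast_exp(x):
+#       return jnp.exp(x)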
diff --git a/src/sugar.py b/src/sugar.py
index 7f5097f005b2340375d2d030bd9308f7b12238c4..36b7ac333e217ef50e63e4edcf78d33a300c0a23 100644
--- a/src/sugar.py
+++ b/src/sugar.py
@@ -60,7 +60,7 @@ def PS_field(pspace, function):
 
     Returns
     -------
-    Field
+    :class:`nifty8.field.Field`
         A field defined on (pspace,) containing the computed function values
     """
     if not isinstance(pspace, PowerSpace):
@@ -119,7 +119,7 @@ def power_analyze(field, spaces=None, binbounds=None,
 
     Parameters
     ----------
-    field : Field
+    field : :class:`nifty8.field.Field`
         The field to be analyzed
     spaces : None or int or tuple of int, optional
         The indices of subdomains for which the power spectrum shall be
@@ -142,7 +142,7 @@ def power_analyze(field, spaces=None, binbounds=None,
 
     Returns
     -------
-    Field
+    :class:`nifty8.field.Field`
         The output object. Its domain is a PowerSpace and it contains
         the power spectrum of `field`.
     """
@@ -203,7 +203,7 @@ def create_power_operator(domain, power_spectrum, space=None,
     ----------
     domain : Domain, tuple of Domain or DomainTuple
         Domain on which the power operator shall be defined.
-    power_spectrum : callable or Field
+    power_spectrum : callable or :class:`nifty8.field.Field`
         An object that contains the power spectrum as a function of k.
     space : int
         the domain index on which the power operator will work
@@ -318,7 +318,7 @@ def full(domain, val):
 
     Returns
     -------
-    Field or MultiField
+    :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The newly created uniform field
     """
     if isinstance(domain, (dict, MultiDomain)):
@@ -344,7 +344,7 @@ def from_random(domain, random_type='normal', dtype=np.float64, **kwargs):
 
     Returns
     -------
-    Field or MultiField
+    :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The newly created random field
 
     Notes
@@ -372,7 +372,7 @@ def makeField(domain, arr):
 
     Returns
     -------
-    Field or MultiField
+    :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         The newly created random field
     """
     if isinstance(domain, (dict, MultiDomain)):
@@ -405,7 +405,7 @@ def makeOp(inp, dom=None, sampling_dtype=None):
 
     Parameters
     ----------
-    inp : None, Field or MultiField
+    inp : None, :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField`
         - if None, None is returned.
         - if Field on scalar-domain, a ScalingOperator with the coefficient
             given by the Field is returned.
diff --git a/test/test_operators/test_interpolated.py b/test/test_operators/test_interpolated.py
index eae1c230d3534992297a1b6dca7f48760a387973..432b7c71c3ccb124dd21f577a1aa549a64c3f1ac 100644
--- a/test/test_operators/test_interpolated.py
+++ b/test/test_operators/test_interpolated.py
@@ -25,7 +25,6 @@ import nifty8 as ift
 
 from ..common import list2fixture, setup_function, teardown_function
 
-pmp = pytest.mark.parametrize
 pmp = pytest.mark.parametrize
 space = list2fixture([ift.GLSpace(15),
                       ift.RGSpace(64, distances=.789),
@@ -34,6 +33,7 @@ seed = list2fixture([4, 78, 23])
 
 
 def testInterpolationAccuracy(space, seed):
+    ift.random.push_sseq_from_seed(seed)
     pos = ift.from_random(space, 'normal')
     alpha = 1.5
     qs = [0.73, pos.ptw("exp").val]
diff --git a/test/test_re/test_energies.py b/test/test_re/test_energies.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df3cd1b5735164518d7c7df2c913622f719a60a
--- /dev/null
+++ b/test/test_re/test_energies.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+
+import jax.numpy as jnp
+import pytest
+from functools import partial
+from jax import random
+from jax.tree_util import tree_map
+from numpy.testing import assert_allclose
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+def lst2fixt(lst):
+    @pytest.fixture(params=lst)
+    def fixt(request):
+        return request.param
+
+    return fixt
+
+
+def random_noise_std_inv(key, shape):
+    diag = 1. / random.exponential(key, shape)
+
+    def noise_std_inv(tangents):
+        return diag * tangents
+
+    return noise_std_inv
+
+
+seed = lst2fixt((3639, 12, 41, 42))
+shape = lst2fixt(((4, 2), (2, 1), (5, )))
+lh_init_true = (
+    (
+        jft.Gaussian, {
+            "data": random.normal,
+            "noise_std_inv": random_noise_std_inv
+        }, None
+    ), (
+        jft.StudentT, {
+            "data": random.normal,
+            "dof": random.exponential,
+            "noise_std_inv": random_noise_std_inv
+        }, None
+    ), (
+        jft.Poissonian, {
+            "data": partial(random.poisson, lam=3.14)
+        }, random.exponential
+    )
+)
+lh_init_approx = (
+    (
+        jft.VariableCovarianceGaussian, {
+            "data": random.normal
+        }, lambda key, shape: (
+            random.normal(key, shape=shape),
+            1. / jnp.exp(random.normal(key, shape=shape))
+        )
+    ), (
+        jft.VariableCovarianceStudentT, {
+            "data": random.normal,
+            "dof": random.exponential
+        }, lambda key, shape: (
+            random.normal(key, shape=shape),
+            jnp.exp(1. + random.normal(key, shape=shape))
+        )
+    )
+)
+
+
+def test_gaussian_vs_vcgaussian_consistency(seed, shape):
+    rtol = 10 * jnp.finfo(jnp.zeros(0).dtype).eps
+    atol = 5 * jnp.finfo(jnp.zeros(0).dtype).eps
+
+    key = random.PRNGKey(seed)
+    sk = list(random.split(key, 5))
+    d = random.normal(sk.pop(), shape=shape)
+    m1 = random.normal(sk.pop(), shape=shape)
+    m2 = random.normal(sk.pop(), shape=shape)
+    t = random.normal(sk.pop(), shape=shape)
+    inv_std = 1. / jnp.exp(1. + random.normal(sk.pop(), shape=shape))
+
+    gauss = jft.Gaussian(d, noise_std_inv=lambda x: inv_std * x)
+    vcgauss = jft.VariableCovarianceGaussian(d)
+
+    diff_g = gauss(m2) - gauss(m1)
+    diff_vcg = vcgauss((m2, inv_std)) - vcgauss((m1, inv_std))
+    assert_allclose(diff_g, diff_vcg, rtol=rtol, atol=atol)
+
+    met_g = gauss.metric(m1, t)
+    met_vcg = vcgauss.metric((m1, inv_std), (t, d / 2))[0]
+    assert_allclose(met_g, met_vcg, rtol=rtol, atol=atol)
+
+
+def test_studt_vs_vcstudt_consistency(seed, shape):
+    rtol = 10 * jnp.finfo(jnp.zeros(0).dtype).eps
+    atol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps
+
+    key = random.PRNGKey(seed)
+    sk = list(random.split(key, 6))
+    d = random.normal(sk.pop(), shape=shape)
+    dof = random.normal(sk.pop(), shape=shape)
+    m1 = random.normal(sk.pop(), shape=shape)
+    m2 = random.normal(sk.pop(), shape=shape)
+    t = random.normal(sk.pop(), shape=shape)
+    inv_std = 1. / jnp.exp(1. + random.normal(sk.pop(), shape=shape))
+
+    studt = jft.StudentT(d, dof, noise_std_inv=lambda x: inv_std * x)
+    vcstudt = jft.VariableCovarianceStudentT(d, dof)
+
+    diff_t = studt(m2) - studt(m1)
+    diff_vct = vcstudt((m2, 1. / inv_std)) - vcstudt((m1, 1. / inv_std))
+    assert_allclose(diff_t, diff_vct, rtol=rtol, atol=atol)
+
+    met_g = studt.metric(m1, t)
+    met_vcg = vcstudt.metric((m1, 1. / inv_std), (t, d / 2))[0]
+    assert_allclose(met_g, met_vcg, rtol=rtol, atol=atol)
+
+
+@pmp("lh_init", lh_init_true + lh_init_approx)
+def test_left_sqrt_metric_vs_metric_consistency(seed, shape, lh_init):
+    rtol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps
+    atol = 0.
+    aallclose = partial(assert_allclose, rtol=rtol, atol=atol)
+
+    N_TRIES = 5
+
+    lh_init_method, draw, latent_init = lh_init
+    key = random.PRNGKey(seed)
+    key, *subkeys = random.split(key, 1 + len(draw))
+    init_kwargs = {
+        k: method(key=sk, shape=shape)
+        for (k, method), sk in zip(draw.items(), subkeys)
+    }
+    lh = lh_init_method(**init_kwargs)
+
+    energy, lsm, lsm_shp = lh.energy, lh.left_sqrt_metric, lh.lsm_tangents_shape
+    # Let JIFTy infer the metric from the left-square-root-metric
+    lh_mini = jft.Likelihood(
+        energy, left_sqrt_metric=lsm, lsm_tangents_shape=lsm_shp
+    )
+
+    rng_method = latent_init if latent_init is not None else random.normal
+    for _ in range(N_TRIES):
+        key, *sk = random.split(key, 3)
+        p = rng_method(sk.pop(), shape=shape)
+        t = rng_method(sk.pop(), shape=shape)
+        tree_map(aallclose, lh.metric(p, t), lh_mini.metric(p, t))
+
+
+@pmp("lh_init", lh_init_true)
+def test_transformation_vs_left_sqrt_metric_consistency(seed, shape, lh_init):
+    rtol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps
+    atol = 0.
+
+    N_TRIES = 5
+
+    lh_init_method, draw, latent_init = lh_init
+    key = random.PRNGKey(seed)
+    key, *subkeys = random.split(key, 1 + len(draw))
+    init_kwargs = {
+        k: method(key=sk, shape=shape)
+        for (k, method), sk in zip(draw.items(), subkeys)
+    }
+    lh = lh_init_method(**init_kwargs)
+    if lh._transformation is None:
+        pytest.skip("no transformation rule implemented yet")
+
+    energy, lsm, lsm_shp = lh.energy, lh.left_sqrt_metric, lh.lsm_tangents_shape
+    # Let JIFTy infer the left-square-root-metric and the metric from the
+    # transformation
+    lh_mini = jft.Likelihood(
+        energy, left_sqrt_metric=lsm, lsm_tangents_shape=lsm_shp
+    )
+
+    rng_method = latent_init if latent_init is not None else random.normal
+    for _ in range(N_TRIES):
+        key, *sk = random.split(key, 3)
+        p = rng_method(sk.pop(), shape=shape)
+        t = rng_method(sk.pop(), shape=shape)
+        assert_allclose(
+            lh.left_sqrt_metric(p, t),
+            lh_mini.left_sqrt_metric(p, t),
+            rtol=rtol,
+            atol=atol
+        )
+        assert_allclose(
+            lh.metric(p, t), lh_mini.metric(p, t), rtol=rtol, atol=atol
+        )
+
+
+if __name__ == "__main__":
+    test_gaussian_vs_vcgaussian_consistency(42, (5, ))
diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eb520c7211f11b8815a67ec55869b5807082268
--- /dev/null
+++ b/test/test_re/test_hmc_1d_distributions.py
@@ -0,0 +1,120 @@
+import sys
+
+from jax import numpy as jnp
+from jax.scipy import stats
+from numpy.testing import assert_allclose
+import pytest
+import scipy
+from scipy.special import comb
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+def mnc2mc(mnc, wmean=True):
+    """Convert non-central to central moments, uses recursive formula
+    optionally adjusts first moment to return mean.
+    """
+
+    # https://www.statsmodels.org/stable/_modules/statsmodels/stats/moment_helpers.html
+    def _local_counts(mnc):
+        mean = mnc[0]
+        mnc = [1] + list(mnc)  # add zero moment = 1
+        mu = []
+        for n, m in enumerate(mnc):
+            mu.append(0)
+            for k in range(n + 1):
+                sgn_comb = (-1)**(n - k) * comb(n, k, exact=True)
+                mu[n] += sgn_comb * mnc[k] * mean**(n - k)
+        if wmean:
+            mu[1] = mean
+        return mu[1:]
+
+    res = jnp.apply_along_axis(_local_counts, 0, mnc)
+    return res
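+
+
+# For example, for N(1, 1) the first four non-central moments are
+# (1, 2, 4, 10) and `mnc2mc(jnp.array([1., 2., 4., 10.]))` recovers
+# (1, 1, 0, 3), i.e. the mean followed by the 2nd to 4th central moments.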
+
+
+# Test simple distributions with no extra parameters
+dists = [stats.cauchy, stats.expon, stats.laplace, stats.logistic, stats.norm]
+# Tuple of `rtol` and `atol` for every tested moment
+moments_tol = {1: (0., 2e-1), 2: (3e-1, 0.), 3: (4e-1, 8e-1), 4: (4., 0.)}
+
+
+@pmp("distribution", dists)
+def test_moment_consistency(distribution, plot=False):
+    name = distribution.__name__.split('.')[-1]
+
+    max_tree_depth = 20
+    sampler = jft.NUTSChain(
+        potential_energy=lambda x: -1 * distribution.logpdf(x),
+        inverse_mass_matrix=1.,
+        position_proto=jnp.array(0.),
+        step_size=0.7193,
+        max_tree_depth=max_tree_depth,
+    )
+    chain, _ = sampler.generate_n_samples(
+        42, jnp.array(1.03890), num_samples=1000, save_intermediates=True
+    )
+
+    # unique, counts = jnp.unique(chain.depths, return_counts=True)
+    # depths_frequencies = jnp.asarray((unique, counts)).T
+
+    if plot is True:
+        import matplotlib.pyplot as plt
+
+        fig, axs = plt.subplots(1, 2)
+
+        bins = jnp.linspace(-10, 10)
+        if distribution is stats.expon:
+            bins = jnp.linspace(0, 10)
+        axs.flat[0].hist(
+            chain.samples, bins=bins, density=True, histtype="step"
+        )
+        axs.flat[0].plot(bins, distribution.pdf(bins), color='r')
+        axs.flat[0].set_title(f"{name} PDF")
+
+        axs.flat[1].hist(
+            chain.depths,
+            bins=jnp.arange(max_tree_depth + 1),
+            density=True,
+            histtype="step"
+        )
+        axs.flat[1].set_title(f"Tree-depth")
+        fig.tight_layout()
+        plt.show()
+
+    # central moments; except for the first (i.e. mean)
+    sample_moms_central = scipy.stats.moment(chain.samples, [1, 2, 3, 4, 5, 6])
+    sample_moms_central[0] = jnp.mean(chain.samples)
+
+    scipy_dist = getattr(scipy.stats, name)
+    dist_moms_non_central = jnp.array(
+        [scipy_dist.moment(i) for i in [1, 2, 3, 4, 5, 6]]
+    )
+    dist_moms_central = mnc2mc(dist_moms_non_central, wmean=True)
+
+    for i, (smpl_mom, dist_mom) in enumerate(
+        zip(sample_moms_central, dist_moms_central), start=1
+    ):
+        msg = (
+            f"{name} (moment {i}) :: sampled: {smpl_mom:+.2e}"
+            f" true: {dist_mom:+.2e} tested: "
+        )
+        print(msg, end="", file=sys.stderr)
+        test = not jnp.isnan(dist_mom)
+        test &= not (jnp.allclose(dist_mom, 0.) and i > 1)
+        if i in moments_tol and test:
+            assert_allclose(
+                dist_mom, smpl_mom,
+                **dict(zip(("rtol", "atol"), moments_tol[i]))
+            )
+            print("✓", file=sys.stderr)
+        else:
+            print("✗", file=sys.stderr)
+
+
+if __name__ == "__main__":
+    for d in dists:
+        test_moment_consistency(d, plot=True)
diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py
new file mode 100644
index 0000000000000000000000000000000000000000..219951905c4d65ed738a74dfddd488ecd5169e03
--- /dev/null
+++ b/test/test_re/test_hmc_hashes.py
@@ -0,0 +1,90 @@
+import sys
+
+from jax import numpy as jnp
+from jax.config import config as jax_config
+from numpy import ndarray
+
+import nifty8.re as jft
+
+
+NDARRAY_TYPE = [ndarray]
+
+try:
+    from jax.numpy import ndarray as jndarray
+
+    NDARRAY_TYPE.append(jndarray)
+except ImportError:
+    pass
+
+NDARRAY_TYPE = tuple(NDARRAY_TYPE)
+
+
+def _json_serialize(obj):
+    if isinstance(obj, NDARRAY_TYPE):
+        return obj.tolist()
+    raise TypeError(f"unknown type {type(obj)}")
+
+
+def hashit(obj, n_chars=8) -> str:
+    """Get first `n_chars` characters of Blake2B hash of `obj`."""
+    import hashlib
+    import json
+
+    return hashlib.blake2b(
+        bytes(json.dumps(obj, default=_json_serialize), "utf-8")
+    ).hexdigest()[:n_chars]
+
+
+def test_hmc_hash():
+    """Test sapmler output against known hash from previous commits."""
+    x0 = jnp.array([0.1, 1.223], dtype=jnp.float32)
+    sampler = jft.HMCChain(
+        potential_energy=lambda x: jnp.sum(x**2),
+        inverse_mass_matrix=1.,
+        position_proto=x0,
+        step_size=0.193,
+        num_steps=100,
+        max_energy_difference=1.
+    )
+    chain, (key, pos) = sampler.generate_n_samples(
+        key=42, initial_position=x0, num_samples=1000, save_intermediates=True
+    )
+    assert chain.divergences.sum() == 0
+    accepted = chain.trees.accepted
+    results = (pos, key, chain.samples, accepted)
+    results_hash = hashit(results, n_chars=20)
+    print(f"full hash: {results_hash}", file=sys.stderr)
+    old_hash = "3d665689f809a98c81b3"
+    assert results_hash == old_hash
+
+
+def test_nuts_hash():
+    """Test sapmler output against known hash from previous commits."""
+    jax_config.update("jax_enable_x64", False)
+
+    x0 = jnp.array([0.1, 1.223], dtype=jnp.float32)
+    sampler = jft.NUTSChain(
+        potential_energy=lambda x: jnp.sum(x**2),
+        inverse_mass_matrix=1.,
+        position_proto=x0,
+        step_size=0.193,
+        max_tree_depth=10,
+        bias_transition=False,
+        max_energy_difference=1.
+    )
+    chain, (key, pos) = sampler.generate_n_samples(
+        key=42, initial_position=x0, num_samples=1000, save_intermediates=False
+    )
+    assert chain.divergences.sum() == 0
+    results = (pos, key, chain.samples)
+    results_hash = hashit(results, n_chars=20)
+    print(f"full hash: {results_hash}", file=sys.stderr)
+    old_hash = "8043850d7249acb77b26"
+    assert results_hash == old_hash
+
+    jax_config.update("jax_enable_x64", True)
+
+
+if __name__ == "__main__":
+    test_hmc_hash()
+    test_nuts_hash()
diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py
new file mode 100644
index 0000000000000000000000000000000000000000..8062d2fc97e7374910edb4ff6970c171c0d3b27d
--- /dev/null
+++ b/test/test_re/test_hmc_leapfrog.py
@@ -0,0 +1,89 @@
+import pytest
+import sys
+from jax import grad
+from jax import numpy as jnp
+from numpy.testing import assert_allclose
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+pot_and_tol = (
+    (
+        lambda q: jnp.sum(
+            q.T @ jnp.linalg.inv(jnp.array([[1, 0.95], [0.95, 1]])) @ q / 2.
+        ),
+        0.2
+    ), (lambda q: -1 / jnp.linalg.norm(q), 2e-2)
+)
+
+
+@pmp("potential_energy, rtol", pot_and_tol)
+def test_leapfrog_energy_conservation(potential_energy, rtol):
+    dims = (2, )
+    mass_matrix = jnp.ones(shape=dims)
+    kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.)
+
+    potential_energy_gradient = grad(potential_energy)
+    positions = [jnp.array([-1.5, -1.55])]
+    momenta = [jnp.array([-1, 1])]
+    for _ in range(25):
+        new_qp = jft.hmc.leapfrog_step(
+            qp=jft.hmc.QP(position=positions[-1], momentum=momenta[-1]),
+            potential_energy_gradient=potential_energy_gradient,
+            kinetic_energy_gradient=lambda x, y: x * y,
+            step_size=0.25,
+            inverse_mass_matrix=1. / mass_matrix
+        )
+        positions.append(new_qp.position)
+        momenta.append(new_qp.momentum)
+
+    potential_energies = list(map(potential_energy, positions))
+    kinetic_energies = list(map(kinetic_energy, momenta))
+
+    jnp.set_printoptions(precision=2)
+    for q, p, e_pot, e_kin in zip(
+        positions, momenta, potential_energies, kinetic_energies
+    ):
+        msg = (
+            f"q: {q}; p: {p}"
+            f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}"
+        )
+        print(msg, file=sys.stderr)
+
+    old_energy_tot = potential_energies[0] + kinetic_energies[0]
+    new_energy_tot = potential_energies[-1] + kinetic_energies[-1]
+    assert_allclose(old_energy_tot, new_energy_tot, rtol=rtol)
+
+    return positions, momenta, kinetic_energies, potential_energies
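+
+
+# For reference, a single kick-drift-kick leapfrog step — a minimal sketch of
+# what `jft.hmc.leapfrog_step` is expected to compute (the actual jft version
+# operates on `QP` pairs and takes the gradient callables as arguments):
+def _leapfrog_step_reference(q, p, step_size, dV_dq, inverse_mass_matrix):
+    p = p - 0.5 * step_size * dV_dq(q)  # half kick
+    q = q + step_size * inverse_mass_matrix * p  # drift
+    p = p - 0.5 * step_size * dV_dq(q)  # half kick
+    return q, p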
+
+
+if __name__ == "__main__":
+    import matplotlib.pyplot as plt
+
+    qs, ps, e_kins, e_pots = test_leapfrog_energy_conservation(*pot_and_tol[0])
+    positions = jnp.array(qs)
+    momenta = jnp.array(ps)
+    kinetic_energies = jnp.array(e_kins)
+    potential_energies = jnp.array(e_pots)
+
+    # Position Coordinates
+    plt.plot(positions[:, 0], positions[:, 1])
+    plt.xlabel("position[:,0]")
+    plt.ylabel("position[:,1]")
+    plt.show()
+
+    # Momentum coordinates
+    plt.plot(momenta[:, 0], momenta[:, 1])
+    plt.xlabel("momenta[:,0]")
+    plt.ylabel("momenta[:,1]")
+    plt.show()
+
+    # Value of the Hamiltonian
+    # NOTE: does not exactly reproduce the corresponding figure in Neal (2011)
+    plt.plot(kinetic_energies, label='kin')
+    plt.plot(potential_energies, label='pot')
+    plt.plot(kinetic_energies + potential_energies, label='total')
+    plt.xlabel('time')
+    plt.ylabel('energy')
+    plt.legend()
+    plt.show()
diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6f1e3c89e3fc1f86819834970cac3506075003
--- /dev/null
+++ b/test/test_re/test_hmc_pytree.py
@@ -0,0 +1,75 @@
+from functools import partial
+from jax import numpy as jnp
+from jax.tree_util import tree_leaves
+from numpy.testing import assert_array_equal
+
+import nifty8.re as jft
+
+
+def test_hmc_pytree():
+    """Test sapmler output against known hash from previous commits."""
+    initial_position = jnp.array([0.31415, 2.71828])
+
+    sampler_init = partial(
+        jft.HMCChain,
+        potential_energy=jft.sum_of_squares,
+        inverse_mass_matrix=1.,
+        step_size=0.193,
+        num_steps=100
+    )
+
+    initial_position_py = jft.Field(({"lvl0": initial_position}, ))
+    smpl_w_pytree = sampler_init(
+        position_proto=initial_position_py
+    ).generate_n_samples(
+        key=321, initial_position=initial_position_py, num_samples=1000
+    )
+    smpl_wo_pytree = sampler_init(
+        position_proto=initial_position
+    ).generate_n_samples(
+        key=321, initial_position=initial_position, num_samples=1000
+    )
+
+    ts_w, ts_wo = tree_leaves(smpl_w_pytree), tree_leaves(smpl_wo_pytree)
+    assert len(ts_w) == len(ts_wo)
+    for w, wo in zip(ts_w, ts_wo):
+        assert_array_equal(w, wo)
+
+
+def test_nuts_pytree():
+    """Test sapmler output against known hash from previous commits."""
+    initial_position = jnp.array([0.31415, 2.71828])
+
+    sampler_init = partial(
+        jft.NUTSChain,
+        potential_energy=jft.sum_of_squares,
+        inverse_mass_matrix=1.,
+        step_size=0.193,
+        max_tree_depth=10,
+    )
+
+    initial_position_py = jft.Field(({"lvl0": initial_position}, ))
+    smpl_w_pytree = sampler_init(
+        position_proto=initial_position_py
+    ).generate_n_samples(
+        key=323, initial_position=initial_position_py, num_samples=1000
+    )
+    smpl_wo_pytree = sampler_init(
+        position_proto=initial_position
+    ).generate_n_samples(
+        key=323, initial_position=initial_position, num_samples=1000
+    )
+
+    ts_w, ts_wo = tree_leaves(smpl_w_pytree), tree_leaves(smpl_wo_pytree)
+    assert len(ts_w) == len(ts_wo)
+    for w, wo in zip(ts_w, ts_wo):
+        assert_array_equal(w, wo)
+
+
+if __name__ == "__main__":
+    test_hmc_pytree()
+    test_nuts_pytree()
diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc4edbac7510c90ccb281abba882bfd259ebae0
--- /dev/null
+++ b/test/test_re/test_lanczos.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from operator import matmul
+import sys
+from jax import random
+import jax.numpy as jnp
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+from scipy.spatial import distance_matrix
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+def matern_kernel(distance, scale, cutoff, dof):
+    from jax.scipy.special import gammaln
+    from scipy.special import kv
+
+    reg_dist = jnp.sqrt(2 * dof) * distance / cutoff
+    norm = scale**2 * 2**(1 - dof) / jnp.exp(gammaln(dof))
+    cov = norm * reg_dist**dof * kv(dof, reg_dist)
+    # NOTE: this is not safe to differentiate through because `cov` may
+    # still contain NaNs
+    return jnp.where(distance < 1e-8 * cutoff, scale**2, cov)
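+
+
+# In standard notation the kernel above is the Matérn covariance
+#   C(r) = scale^2 * 2^(1 - dof) / Gamma(dof)
+#          * (sqrt(2 dof) r / cutoff)^dof * K_dof(sqrt(2 dof) r / cutoff)
+# with K_dof the modified Bessel function of the second kind.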
+
+
+@pmp("seed", tuple(range(12, 44, 5)))
+@pmp("shape0", (128, 64))
+def test_lanczos_tridiag(seed, shape0):
+    rng = np.random.default_rng(seed)
+    rng_key = random.PRNGKey(rng.integers(12, 42))
+
+    m = rng.normal(size=(shape0, ) * 2)
+    m = m @ m.T  # ensure positive-definiteness
+
+    tridiag, vecs = jft.lanczos.lanczos_tridiag(
+        partial(matmul, m), jft.ShapeWithDtype((shape0, )), shape0, rng_key
+    )
+    m_est = vecs.T @ tridiag @ vecs
+
+    assert_allclose(m_est, m, atol=1e-13, rtol=1e-13)
+
+
+@pmp("seed", tuple(range(12, 44, 5)))
+@pmp("shape0", (128, 64))
+def test_stochastic_lq_logdet(seed, shape0, lq_order=15, n_lq_samples=10):
+    rng = np.random.default_rng(seed)
+    rng_key = random.PRNGKey(rng.integers(12, 42))
+
+    c = np.exp(3 + rng.normal())
+    s = np.exp(rng.normal())
+
+    # log-spaced sample positions between 0.1 * c and 1e+2 * c
+    p = np.geomspace(0.1 * c, 1e+2 * c, num=shape0 - 1)
+    p = np.concatenate(([0], p)).reshape(-1, 1)
+
+    m = jnp.asarray(
+        matern_kernel(distance_matrix(p, p), cutoff=c, scale=s, dof=2.5)
+    )
+
+    _, logdet = jnp.linalg.slogdet(m)
+    logdet_est = jft.stochastic_lq_logdet(m, lq_order, n_lq_samples, rng_key)
+    print(f"{logdet=} :: {logdet_est=}", file=sys.stderr)
+    assert_allclose(logdet_est, logdet, rtol=2., atol=20.)
diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbd9737d189c1fa66c4ca5a65582d1c1a839ce52
--- /dev/null
+++ b/test/test_re/test_ncg.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+
+import sys
+
+from jax import random, value_and_grad
+import jax.numpy as jnp
+from numpy.testing import assert_allclose
+import pytest
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+def rosenbrock(np):
+    def func(x):
+        return np.sum(100. * np.diff(x)**2 + (1. - x[:-1])**2)
+
+    return func
+
+
+def himmelblau(np):
+    def func(p):
+        x, y = p
+        return (x**2 + y - 11.)**2 + (x + y**2 - 7.)**2
+
+    return func
+
+
+def matyas(np):
+    def func(p):
+        x, y = p
+        return 0.26 * (x**2 + y**2) - 0.48 * x * y
+
+    return func
+
+
+def eggholder(np):
+    def func(p):
+        x, y = p
+        return -(y + 47) * np.sin(
+            np.sqrt(np.abs(x / 2. + y + 47.))
+        ) - x * np.sin(np.sqrt(np.abs(x - (y + 47.))))
+
+    return func
+
+
+def test_ncg_for_pytree():
+    pos = jft.Field(
+        [
+            jnp.array(0., dtype=jnp.float32),
+            (jnp.array(3., dtype=jnp.float32), ), {
+                "a": jnp.array(5., dtype=jnp.float32)
+            }
+        ]
+    )
+    getters = (lambda x: x[0], lambda x: x[1][0], lambda x: x[2]["a"])
+    tgt = [-10., 1., 2.]
+    met = [10., 40., 2.]
+
+    def model(p):
+        losses = []
+        for i, get in enumerate(getters):
+            losses.append((get(p) - tgt[i])**2 * met[i])
+        return jnp.sum(jnp.array(losses))
+
+    def metric(p, tan):
+        m = []
+        m.append(tan[0] * met[0])
+        m.append((tan[1][0] * met[1], ))
+        m.append({"a": tan[2]["a"] * met[2]})
+        return jft.Field(m)
+
+    res = jft.newton_cg(
+        fun_and_grad=value_and_grad(model),
+        x0=pos,
+        hessp=metric,
+        maxiter=10,
+        absdelta=1e-6
+    )
+    for i, get in enumerate(getters):
+        assert_allclose(get(res), tgt[i], atol=1e-6, rtol=1e-5)
+
+
+@pmp("seed", (3637, 12, 42))
+def test_ncg(seed):
+    key = random.PRNGKey(seed)
+    x = random.normal(key, shape=(3, ))
+    diag = jnp.array([1., 2., 3.])
+    met = lambda y, t: t / diag
+    val_and_grad = lambda y: (
+        jnp.sum(y**2 / diag) / 2 - jnp.dot(x, y), y / diag - x
+    )
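+    # The gradient y / diag - x vanishes at y = diag * x, the value asserted
+    # below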
+
+    res = jft.newton_cg(
+        fun_and_grad=val_and_grad,
+        x0=x,
+        hessp=met,
+        maxiter=20,
+        absdelta=1e-6,
+        name='N'
+    )
+    assert_allclose(res, diag * x, rtol=1e-4, atol=1e-4)
+
+
+@pmp("seed", (3637, 12, 42))
+@pmp("cg", (jft.cg, jft.static_cg))
+def test_cg(seed, cg):
+    key = random.PRNGKey(seed)
+    sk = random.split(key, 2)
+    x = random.normal(sk[0], shape=(3, ))
+    # Avoid poorly conditioned matrices by shifting the elements from zero
+    diag = 6. + random.normal(sk[1], shape=(3, ))
+    mat = lambda x: x / diag
+
+    res, _ = cg(mat, x, resnorm=1e-5, absdelta=1e-5)
+    assert_allclose(res, diag * x, rtol=1e-4, atol=1e-4)
+
+
+@pmp("seed", (3637, 12, 42))
+@pmp("cg", (jft.cg, jft.static_cg))
+def test_cg_non_pos_def_failure(seed, cg):
+    key = random.PRNGKey(seed)
+    sk = random.split(key, 2)
+
+    x = random.normal(sk[0], shape=(4, ))
+    # Purposely produce a non-positive definite matrix
+    diag = jnp.concatenate(
+        (jnp.array([-1]), 6. + random.normal(sk[1], shape=(3, )))
+    )
+    mat = lambda x: x / diag
+
+    with pytest.raises(ValueError):
+        _, info = cg(mat, x, resnorm=1e-5, absdelta=1e-5)
+        if info < 0:
+            raise ValueError()
+
+
+@pmp("seed", (3637, 12, 42))
+def test_cg_steihaug(seed):
+    key = random.PRNGKey(seed)
+    sk = random.split(key, 2)
+    x = random.normal(sk[0], shape=(3, ))
+    # Avoid poorly conditioned matrices by shifting the elements from zero
+    diag = 6. + random.normal(sk[1], shape=(3, ))
+    mat = lambda x: x / diag
+
+    # NOTE: with an infinite trust radius, the Steihaug subproblem minimizer
+    # is the full Newton step -H^{-1} g, i.e. the CG solution up to a sign
+    res = jft.conjugate_gradient._cg_steihaug_subproblem(
+        jnp.nan, -x, mat, resnorm=1e-6, trust_radius=jnp.inf
+    )
+    assert_allclose(res.step, diag * x, rtol=1e-4, atol=1e-4)
+
+
+@pmp("seed", (3637, 12, 42))
+@pmp("size", (5, 9, 14))
+def test_cg_steihaug_vs_cg_consistency(seed, size):
+    key = random.PRNGKey(seed)
+    sk = random.split(key, 2)
+
+    x = random.normal(sk[0], shape=(size, ))
+    # Avoid poorly conditioned matrices by shifting the elements from zero
+    mat_val = 6. + random.normal(sk[1], shape=(size, size))
+    mat_val = mat_val @ mat_val.T  # Construct a symmetric matrix
+    mat = lambda x: mat_val @ x
+
+    # NOTE: with an infinite trust radius, the Steihaug subproblem minimizer
+    # is the CG solution up to a sign, here checked iteration by iteration
+    for i in range(4):
+        print(f"Iteratoin {i:02d}", file=sys.stderr)
+        res_cgs = jft.conjugate_gradient._cg_steihaug_subproblem(
+            jnp.nan,
+            -x,
+            mat,
+            resnorm=1e-6,
+            trust_radius=jnp.inf,
+            miniter=i,
+            maxiter=i
+        )
+        res_cg_plain, _ = jft.conjugate_gradient.cg(
+            mat, x, resnorm=1e-6, miniter=i, maxiter=i
+        )
+        assert_allclose(res_cgs.step, res_cg_plain, rtol=1e-4, atol=1e-5)
+
+
+@pmp(
+    "fun_and_init", (
+        (rosenbrock, jnp.zeros(2)), (himmelblau, jnp.zeros(2)),
+        (matyas, jnp.ones(2) * 6.), (eggholder, jnp.ones(2) * 100.)
+    )
+)
+@pmp("maxiter", (jnp.inf, None))
+def test_minimize(fun_and_init, maxiter):
+    from scipy.optimize import minimize as opt_minimize
+    from jax import grad, hessian
+
+    func, x0 = fun_and_init
+
+    def jft_minimize(x0):
+        result = jft.minimize(
+            func(jnp),
+            x0,
+            method='trust-ncg',
+            options=dict(
+                maxiter=maxiter,
+                energy_reduction_factor=None,
+                gtol=1e-6,
+                initial_trust_radius=1.,
+                max_trust_radius=1000.
+            ),
+        )
+        return result.x
+
+    def scp_minimize(x0):
+        # Use JAX primitives to take derivatives
+        fun = func(jnp)
+        result = opt_minimize(
+            fun, x0, jac=grad(fun), hess=hessian(fun), method='trust-ncg'
+        )
+        return result.x
+
+    jax_res = jft_minimize(x0)
+    scipy_res = scp_minimize(x0)
+    assert_allclose(scipy_res, jax_res, rtol=2e-6, atol=2e-5)
+
+
+if __name__ == "__main__":
+    test_ncg_for_pytree()
diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ad0075073a0a44c1109617d34116fc4f3a83f29
--- /dev/null
+++ b/test/test_re/test_refine.py
@@ -0,0 +1,382 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+import sys
+
+import jax
+from jax import random
+import jax.numpy as jnp
+from jax.tree_util import Partial
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+from scipy.spatial import distance_matrix
+
+import nifty8.re as jft
+from nifty8.re import refine, refine_chart
+
+pmp = pytest.mark.parametrize
+
+
+def matern_kernel(distance, scale, cutoff, dof):
+    from jax.scipy.special import gammaln
+    from scipy.special import kv
+
+    reg_dist = jnp.sqrt(2 * dof) * distance / cutoff
+    norm = scale**2 * 2**(1 - dof) / jnp.exp(gammaln(dof))
+    return norm * reg_dist**dof * kv(dof, reg_dist)
+
+
+scale, cutoff, dof = 1., 80., 3 / 2
+
+x = jnp.logspace(-6, 11, base=jnp.e, num=int(1e+5))
+y = matern_kernel(x, scale, cutoff, dof)
+y = jnp.nan_to_num(y, nan=0.)
+kernel = Partial(jnp.interp, xp=x, fp=y)
+inv_kernel = Partial(jnp.interp, xp=y, fp=x)
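+
+# NOTE: tabulating the kernel on log-spaced distances and evaluating it via
+# `jnp.interp` keeps it cheap and JAX-traceable; the Bessel-function form
+# using `scipy.special.kv` cannot be traced by JAX.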
+
+
+@pmp("dist", (10., 20., 30., 1e+3))
+def test_refinement_matrices_1d(dist, kernel=kernel):
+    cov_from_loc = refine._get_cov_from_loc(kernel=kernel)
+
+    coarse_coord = dist * jnp.array([0., 1., 2.])
+    fine_coord = coarse_coord[tuple(
+        jnp.array(coarse_coord.shape) // 2
+    )] + (jnp.diff(coarse_coord) / jnp.array([-4., 4.]))
+    cov_ff = cov_from_loc(fine_coord, fine_coord)
+    cov_fc = cov_from_loc(fine_coord, coarse_coord)
+    cov_cc_inv = jnp.linalg.inv(cov_from_loc(coarse_coord, coarse_coord))
+
+    fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T
+    fine_kernel_sqrt_diy = jnp.linalg.cholesky(fine_kernel)
+    olf_diy = cov_fc @ cov_cc_inv
+
+    olf, fine_kernel_sqrt = refine.layer_refinement_matrices(dist, kernel)
+
+    assert_allclose(olf, olf_diy)
+    assert_allclose(fine_kernel_sqrt, fine_kernel_sqrt_diy)
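+
+
+# The assertions above check the standard Gaussian conditioning formulas:
+# for jointly Gaussian fine/coarse values with covariance blocks cov_ff,
+# cov_fc and cov_cc, the fine values conditioned on the coarse ones have
+#   mean = cov_fc @ inv(cov_cc) @ coarse    (the optimal linear filter)
+#   cov  = cov_ff - cov_fc @ inv(cov_cc) @ cov_fc.T
+# whose Cholesky factor is the `fine_kernel_sqrt` applied to the excitations.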
+
+
+@pmp("seed", (12, 42, 43, 45))
+@pmp("dist", (10., 20., 30., 1e+3))
+def test_refinement_1d(seed, dist, kernel=kernel):
+    rng = np.random.default_rng(seed)
+
+    refs = (
+        refine.refine_conv, refine.refine_conv_general, refine.refine_loop,
+        refine.refine_vmap, refine.refine_loop, refine.refine_slice
+    )
+    cov_from_loc = refine._get_cov_from_loc(kernel=kernel)
+    olf, fine_kernel_sqrt = refine.layer_refinement_matrices(dist, kernel)
+
+    main_coord = jnp.linspace(0., 1000., 50)
+    cov_sqrt = jnp.linalg.cholesky(cov_from_loc(main_coord, main_coord))
+    lvl0 = cov_sqrt @ rng.normal(size=main_coord.shape)
+    lvl1_exc = rng.normal(size=(2 * (lvl0.size - 2), ))
+
+    fine_reference = refine.refine(lvl0, lvl1_exc, olf, fine_kernel_sqrt)
+    eps = jnp.finfo(lvl0.dtype.type).eps
+    aallclose = partial(
+        assert_allclose, desired=fine_reference, rtol=6 * eps, atol=60 * eps
+    )
+    for ref in refs:
+        print(f"testing {ref.__name__}", file=sys.stderr)
+        aallclose(ref(lvl0, lvl1_exc, olf, fine_kernel_sqrt))
+
+
+@pmp("seed", (12, 42))
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_refinement_nd_cross_consistency(
+    seed, dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+    ndim = len(dist) if hasattr(dist, "__len__") else 1
+    min_shape = (12, ) * ndim
+    depth = 1
+    refs = (refine.refine_conv_general, refine.refine_slice)
+    kwargs = {
+        "_coarse_size": _coarse_size,
+        "_fine_size": _fine_size,
+        "_fine_strategy": _fine_strategy
+    }
+
+    chart = refine_chart.CoordinateChart(
+        min_shape, depth=depth, distances=dist, **kwargs
+    )
+    rfm = refine_chart.RefinementField(chart).matrices(kernel)
+    xi = jft.random_like(
+        random.PRNGKey(seed),
+        refine_chart.RefinementField(chart).shapewithdtype
+    )
+
+    cf = partial(refine_chart.RefinementField.apply, chart=chart, kernel=rfm)
+    fine_reference = cf(xi)
+    eps = jnp.finfo(fine_reference.dtype.type).eps
+    aallclose = partial(
+        assert_allclose, desired=fine_reference, rtol=6 * eps, atol=60 * eps
+    )
+    for ref in refs:
+        print(f"testing {ref.__name__}", file=sys.stderr)
+        aallclose(cf(xi, _refine=ref))
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+def test_refinement_fine_strategy_basic_consistency(dist, kernel=kernel):
+    olf_j, ks_j = refine.layer_refinement_matrices(
+        dist, kernel=kernel, _fine_size=2, _fine_strategy="jump"
+    )
+    olf_e, ks_e = refine.layer_refinement_matrices(
+        dist, kernel=kernel, _fine_size=2, _fine_strategy="extend"
+    )
+
+    assert_allclose(olf_j, olf_e, rtol=1e-13, atol=0.)
+    assert_allclose(ks_j, ks_e, rtol=1e-13, atol=0.)
+
+    shape0 = (12, ) * len(dist) if isinstance(dist, tuple) else (12, )
+    depth = 2
+    olfs_j, (csq0_j, kss_j) = refine.refinement_matrices(
+        shape0, depth, dist, kernel=kernel, _fine_strategy="jump"
+    )
+    olfs_e, (csq0_e, kss_e) = refine.refinement_matrices(
+        shape0, depth, dist, kernel=kernel, _fine_strategy="extend"
+    )
+
+    assert_allclose(olfs_j, olfs_e, rtol=1e-13, atol=0.)
+    assert_allclose(kss_j, kss_e, rtol=1e-13, atol=0.)
+    assert_allclose(csq0_j, csq0_e, rtol=1e-13, atol=0.)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_refinement_covariance(
+    dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+    distances0 = np.atleast_1d(dist)
+    ndim = len(distances0)
+
+    cf = refine_chart.RefinementField(
+        shape0=(_coarse_size, ) * ndim,
+        depth=1,
+        _coarse_size=_coarse_size,
+        _fine_size=_fine_size,
+        _fine_strategy=_fine_strategy,
+        distances0=distances0,
+        kernel=kernel
+    )
+    exc_shp = [
+        jft.ShapeWithDtype((_coarse_size, ) * ndim),
+        jft.ShapeWithDtype((_fine_size, ) * ndim)
+    ]
+    cf_shp = jax.eval_shape(cf, exc_shp)
+    assert cf_shp.shape == (_fine_size, ) * ndim
+
+    probe = jnp.zeros(cf_shp.shape)
+    indices = np.indices(cf_shp.shape).reshape(ndim, -1)
+    # Work around jax.linear_transpose NotImplementedError
+    _, cf_T = jax.vjp(cf, jft.zeros_like(exc_shp))
+    cf_cf_T = lambda x: cf(*cf_T(x))
+    cov_empirical = jax.vmap(
+        lambda idx: cf_cf_T(probe.at[tuple(idx)].set(1.)).ravel(),
+        in_axes=1,
+        out_axes=-1
+    )(indices)
+
+    pos = np.mgrid[tuple(slice(s) for s in cf_shp.shape)].astype(float)
+    if _fine_strategy == "jump":
+        pos *= distances0.reshape((-1, ) + (1, ) * ndim) / _fine_size
+    elif _fine_strategy == "extend":
+        pos *= distances0.reshape((-1, ) + (1, ) * ndim) / 2
+    else:
+        raise AssertionError(f"invalid `_fine_strategy`; {_fine_strategy}")
+    pos = jnp.moveaxis(pos, 0, -1)
+    p = pos.reshape(-1, ndim)
+    dist_mat = distance_matrix(p, p)
+    cov_truth = kernel(dist_mat)
+
+    assert_allclose(cov_empirical, cov_truth, rtol=1e-14, atol=1e-15)
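+
+
+# The probing above relies on `cf` being linear in the excitations: for a
+# linear map xi -> A @ xi of standard-normal xi, the covariance is A @ A.T,
+# so applying cf after its transpose to every unit vector extracts the model
+# covariance column by column for comparison with the analytic kernel.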
+
+
+@pmp("seed", (12, 42, 43, 45))
+@pmp("n_dim", (1, 2, 3, 4, 5))
+def test_refinement_nd_shape(seed, n_dim, kernel=kernel):
+    rng = np.random.default_rng(seed)
+
+    distances = np.exp(rng.normal(size=(n_dim, )))
+    cov_from_loc = refine._get_cov_from_loc(kernel=kernel)
+    olf, fine_kernel_sqrt = refine.layer_refinement_matrices(distances, kernel)
+
+    shp_i = 5
+    gc = distances.reshape(n_dim, 1) * jnp.linspace(0., 1000., shp_i)
+    gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1).reshape(-1, n_dim)
+    cov_sqrt = jnp.linalg.cholesky(cov_from_loc(gc, gc))
+    lvl0 = (cov_sqrt @ rng.normal(size=gc.shape[0])).reshape((shp_i, ) * n_dim)
+    lvl1_exc = rng.normal(size=tuple(n - 2 for n in lvl0.shape) + (2**n_dim, ))
+
+    fine_reference = refine.refine(lvl0, lvl1_exc, olf, fine_kernel_sqrt)
+    assert fine_reference.shape == tuple((2 * (shp_i - 2), ) * n_dim)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_chart_pixel_refinement_matrices_consistency(
+    dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+    depth = 3
+    distances = np.atleast_1d(dist)
+    kwargs = {
+        "_coarse_size": _coarse_size,
+        "_fine_size": _fine_size,
+        "_fine_strategy": _fine_strategy
+    }
+
+    cc = refine_chart.CoordinateChart(
+        (12, ) * distances.size, depth=depth, distances=distances, **kwargs
+    )
+    olf, ks = refine_chart.RefinementField(cc).matrices_at(
+        level=depth, pixel_index=(0, ) * distances.size, kernel=kernel
+    )
+    olf_classical, ks_classical = refine.layer_refinement_matrices(
+        distances, kernel, **kwargs
+    )
+    assert_allclose(olf, olf_classical, atol=1e-14, rtol=1e-14)
+    assert_allclose(ks, ks_classical, atol=1e-14, rtol=1e-14)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_chart_refinement_matrices_consistency(
+    dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+    depth = 3
+    distances = np.atleast_1d(dist)
+    ndim = distances.size
+    kwargs = {
+        "_coarse_size": _coarse_size,
+        "_fine_size": _fine_size,
+        "_fine_strategy": _fine_strategy
+    }
+
+    cc = refine_chart.CoordinateChart(
+        (12, ) * ndim, depth=depth, distances=distances, **kwargs
+    )
+    refinement = refine_chart.RefinementField(cc).matrices(kernel=kernel)
+
+    cc_irreg = refine_chart.CoordinateChart(
+        shape0=cc.shape0,
+        depth=depth,
+        distances=distances,
+        irregular_axes=tuple(range(ndim)),
+        **kwargs
+    )
+    refinement_irreg = refine_chart.RefinementField(cc_irreg).matrices(
+        kernel=kernel
+    )
+
+    _, (cov_sqrt0, _) = refine.refinement_matrices(
+        cc.shape0, 0, cc.distances0, kernel, **kwargs
+    )
+
+    aallclose = partial(assert_allclose, rtol=1e-14, atol=1e-13)
+    aallclose(refinement.cov_sqrt0, cov_sqrt0)
+    aallclose(refinement_irreg.cov_sqrt0, cov_sqrt0)
+
+    for lvl in range(depth):
+        olf, ks = refinement.filter[lvl], refinement.propagator_sqrt[lvl]
+        olf_irreg = refinement_irreg.filter[lvl]
+        ks_irreg = refinement_irreg.propagator_sqrt[lvl]
+
+        if _fine_strategy == "jump":
+            distances_lvl = cc.distances0 / _fine_size**lvl
+        elif _fine_strategy == "extend":
+            distances_lvl = cc.distances0 / 2**lvl
+        else:
+            raise AssertionError()
+        olf_classical, ks_classical = refine.layer_refinement_matrices(
+            distances_lvl, kernel, **kwargs
+        )
+
+        aallclose(olf.squeeze(), olf_classical)
+        aallclose(ks.squeeze(), ks_classical)
+
+        olf_d = np.diff(
+            olf_irreg.reshape((-1, ) + olf_irreg.shape[-2:]), axis=0
+        )
+        ks_d = np.diff(ks_irreg.reshape((-1, ) + ks_irreg.shape[-2:]), axis=0)
+        aallclose(olf_d, 0.)
+        aallclose(ks_d, 0.)
+        aallclose(olf_irreg[(0, ) * ndim], olf_classical)
+        aallclose(ks_irreg[(0, ) * ndim], ks_classical)
+
+
+@pmp("seed", (12, ))
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+@pmp("_refine", (refine.refine_conv_general, refine.refine_slice))
+def test_refinement_irregular_regular_consistency(
+    seed,
+    dist,
+    _coarse_size,
+    _fine_size,
+    _fine_strategy,
+    _refine,
+    kernel=kernel
+):
+    depth = 1
+    distances = np.atleast_1d(dist)
+    ndim = distances.size
+    kwargs = {
+        "_coarse_size": _coarse_size,
+        "_fine_size": _fine_size,
+        "_fine_strategy": _fine_strategy
+    }
+
+    cc = refine_chart.RefinementField(
+        shape0=(2 * _coarse_size, ) * ndim,
+        depth=depth,
+        distances=distances,
+        **kwargs
+    )
+    refinement = cc.matrices(kernel=kernel)
+
+    cc_irreg = refine_chart.RefinementField(
+        shape0=cc.chart.shape0,
+        depth=depth,
+        distances=distances,
+        irregular_axes=tuple(range(ndim)),
+        **kwargs
+    )
+    refinement_irreg = cc_irreg.matrices(kernel=kernel)
+
+    rng = np.random.default_rng(seed)
+    exc_swd = cc.shapewithdtype[-1]
+    fn1 = rng.normal(size=cc.chart.shape_at(depth - 1))
+    exc = rng.normal(size=exc_swd.shape)
+
+    refined = _refine(
+        fn1, exc, refinement.filter[-1], refinement.propagator_sqrt[-1],
+        **kwargs
+    )
+    refined_irreg = _refine(
+        fn1, exc, refinement_irreg.filter[-1],
+        refinement_irreg.propagator_sqrt[-1], **kwargs
+    )
+    assert_allclose(refined_irreg, refined, rtol=1e-14, atol=1e-13)
+
+
+if __name__ == "__main__":
+    test_refinement_matrices_1d(5.)
+    test_refinement_1d(42, 10.)
diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..36df85c0d076b512c1c0d541bf3ae8a3e8b94bd6
--- /dev/null
+++ b/test/test_re/test_refine_util.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+
+import jax
+import numpy as np
+import pytest
+
+from nifty8.re import refine_chart, refine_util
+
+pmp = pytest.mark.parametrize
+
+
+@pmp("shape0", ((16, ), (13, 15), (11, 12, 13)))
+@pmp("depth", (1, 2))
+@pmp("_coarse_size", (3, 5, 7))
+@pmp("_fine_size", (2, 4, 6))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_shape_translations(
+    shape0, depth, _coarse_size, _fine_size, _fine_strategy
+):
+    kwargs = {
+        "_coarse_size": _coarse_size,
+        "_fine_size": _fine_size,
+        "_fine_strategy": _fine_strategy
+    }
+
+    def cf(shape0, xi):
+        chart = refine_chart.CoordinateChart(
+            shape0=shape0,
+            depth=depth,
+            distances0=(1., ) * len(shape0),
+            **kwargs
+        )
+        return refine_chart.RefinementField.apply(
+            xi, chart=chart, kernel=lambda x: x
+        )
+
+    dom = refine_util.get_refinement_shapewithdtype(shape0, depth, **kwargs)
+    tgt = jax.eval_shape(partial(cf, shape0), dom)
+    tgt_pred_shp = refine_util.coarse2fine_shape(shape0, depth, **kwargs)
+    assert tgt_pred_shp == tgt.shape
+    assert dom[-1].size == tgt.size == np.prod(tgt_pred_shp)
+
+    shape0_pred = refine_util.fine2coarse_shape(tgt.shape, depth, **kwargs)
+    dom_pred = refine_util.get_refinement_shapewithdtype(
+        shape0_pred, depth, **kwargs
+    )
+    tgt_pred = jax.eval_shape(partial(cf, shape0_pred), dom_pred)
+
+    assert tgt.shape == tgt_pred.shape
+    if _fine_strategy == "jump":
+        assert shape0_pred == shape0
+    else:
+        assert _fine_strategy == "extend"
+        assert all(s0_p <= s0 for s0_p, s0 in zip(shape0_pred, shape0))
+
+
+@pmp("seed", (42, 45))
+def test_gauss_kl(seed, n_resamples=100):
+    rng = np.random.default_rng(seed)
+    for _ in range(n_resamples):
+        d = max(rng.poisson(4), 1)
+        m_t = rng.normal(size=(d, d))
+        m_t = m_t @ m_t.T
+        scl = rng.lognormal(2., 3.)
+
+        np.testing.assert_allclose(
+            refine_util.gauss_kl(m_t, m_t), 0., atol=1e-11
+        )
+        kl_rhs_scl = 0.5 * d * (np.log(scl) + 1. / scl - 1.)
+        np.testing.assert_allclose(
+            kl_rhs_scl, refine_util.gauss_kl(m_t, scl * m_t), rtol=1e-11
+        )
+        kl_lhs_scl = 0.5 * d * (-np.log(scl) + scl - 1.)
+        np.testing.assert_allclose(
+            kl_lhs_scl, refine_util.gauss_kl(scl * m_t, m_t), rtol=1e-10
+        )
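+
+
+# The reference values follow from the closed form for zero-mean Gaussians
+# (assuming `gauss_kl(p, q)` computes KL(N(0, p) || N(0, q))):
+#   KL = 0.5 * (tr(inv(q) @ p) - d + ln det q - ln det p)
+# which for q = s * p reduces to 0.5 * d * (ln s + 1/s - 1) and for
+# p = s * q to 0.5 * d * (s - 1 - ln s).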
diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e021f84d5a903f6c23c59a14f8ad7edb6de41
--- /dev/null
+++ b/test/test_re/test_stats_distributions.py
@@ -0,0 +1,41 @@
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+@pmp("a", (3., 1.5, 4.))
+@pmp("scale", (2., 4.))
+@pmp("loc", (2., 4., 0.))
+@pmp("seed", (42, 43))
+def test_invgamma_roundtrip(a, scale, loc, seed, step=1e-1):
+    rng = np.random.default_rng(seed)
+
+    n_samples = int(1e+4)
+    n_rvs = rng.normal(loc=0., scale=2., size=(n_samples, ))
+    n_rvs = n_rvs.clip(-5.2, 5.2)
+
+    pr = jft.invgamma_prior(a, scale, loc=loc, step=step)
+    ipr = jft.invgamma_invprior(a, scale, loc=loc, step=step)
+
+    n_roundtrip = ipr(pr(n_rvs))
+    assert_allclose(n_roundtrip, n_rvs, rtol=1e-4, atol=1e-3)
+
+
+@pmp("mean", (2., 4.))
+@pmp("std", (2., 4.))
+@pmp("seed", (42, 43))
+def test_lognormal_roundtrip(mean, std, seed):
+    rng = np.random.default_rng(seed)
+
+    n_samples = int(1e+4)
+    n_rvs = rng.normal(loc=0., scale=2., size=(n_samples, ))
+
+    pr = jft.lognormal_prior(mean, std)
+    ipr = jft.lognormal_invprior(mean, std)
+
+    n_roundtrip = ipr(pr(n_rvs))
+    assert_allclose(n_roundtrip, n_rvs, rtol=1e-6, atol=1e-6)
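+
+
+# Both priors reparametrize a standard-normal variable: the lognormal prior
+# maps n to exp(mu + sigma * n) with mu and sigma chosen to match the
+# requested mean and std, and the inverse-gamma prior maps n through the
+# inverse CDF, roughly x = F^{-1}(Phi(n)); the latter is tabulated with
+# resolution `step`, hence the clipping and looser tolerances above.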