diff --git a/src/re/__init__.py b/src/re/__init__.py
index 1c1b9b88721c737d667a0f20849d34597aa03692..d6e0efce1a5520e8eec0d4b4598358f5ba558af9 100644
--- a/src/re/__init__.py
+++ b/src/re/__init__.py
@@ -3,7 +3,7 @@
 from .. import config
 from . import structured_kernel_interpolation
 from .conjugate_gradient import cg, static_cg
-from .correlated_field import CorrelatedFieldMaker, non_parametric_amplitude
+from .correlated_field import CorrelatedFieldMaker
 from .custom_map import lmap, smap
 from .evi import (
     Samples,
diff --git a/src/re/correlated_field.py b/src/re/correlated_field.py
index bdcb620ec06c8dd45dcf4c42f4f18b4a581e3b99..52028afb3cdea9b64096d4a394242055bf75dd60 100644
--- a/src/re/correlated_field.py
+++ b/src/re/correlated_field.py
@@ -1,15 +1,16 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
-# Authors: Gordian Edenhofer, Philipp Frank
+# Authors: Gordian Edenhofer, Philipp Frank,
+# Matteo Guardiani, Julian Rüstig
 
 import operator
+import dataclasses
 from collections import namedtuple
 from collections.abc import Mapping
-from functools import partial
+from functools import partial, reduce
 from typing import Callable, Optional, Tuple, Union
 
 import numpy as np
-from jax import numpy as jnp
-from jax import vmap
+from jax import numpy as jnp, vmap
 
 from ..config import _config
 from .gauss_markov import IntegratedWienerProcess
@@ -267,7 +268,7 @@ def make_grid(
             msg = "`shape` must be length one. Its the nside of the spherical grid."
             raise ValueError(msg)
         nside = shape[0]
-        (m_length_idx, m_length, m_count), (lmax, mmax, size) = (
+        ((m_length_idx, m_length, m_count), (lmax, mmax, size)) = (
             get_spherical_mode_distributor(nside)
         )
         um, log_vol = _log_modes(m_length)
@@ -298,157 +299,225 @@ def _remove_slope(rel_log_mode_dist, x):
     return x - x[-1] * sc
 
 
-def matern_amplitude(
-    grid,
-    scale: Callable,
-    cutoff: Callable,
-    loglogslope: Callable,
-    renormalize_amplitude: bool,
-    prefix: str = "",
-    kind: str = "amplitude",
-) -> Model:
-    """Constructs a function computing the amplitude of a Matérn-kernel
-    power spectrum.
-
-    See
-    :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations
-    _matern`
-    for more details on the parameters.
-
-    See also
-    --------
-    `Causal, Bayesian, & non-parametric modeling of the SARS-CoV-2 viral
-    load vs. patient's age`, Guardiani, Matteo and Frank, Philipp and Kostić,
-    Andrija and Edenhofer, Gordian and Roth, Jakob and Uhlmann, Berit and
-    Enßlin, Torsten, `<https://arxiv.org/abs/2105.13483>`_
-    `<https://doi.org/10.1371/journal.pone.0275011>`_
-    """
-    totvol = grid.total_volume
-    mode_lengths = grid.harmonic_grid.mode_lengths
-    mode_multiplicity = grid.harmonic_grid.mode_multiplicity
+class MaternAmplitude(Model):
+    cutoff: Callable = dataclasses.field(metadata=dict(static=False))
+    loglogslope: Callable = dataclasses.field(metadata=dict(static=False))
+    scale: Callable = dataclasses.field(metadata=dict(static=False))
+
+    def __init__(
+        self,
+        grid: Union[RegularCartesianGrid, RegularFourierGrid, HEALPixGrid, LMGrid],
+        scale: Optional[Callable],
+        cutoff: Callable,
+        loglogslope: Callable,
+        renormalize_amplitude: bool,
+        prefix: str = "",
+        kind: str = "amplitude",
+    ):
+        """Initializes a model that computes the amplitude of a Matérn-kernel
+        power spectrum.
+
+        See
+        :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations
+        _matern`
+        for more details on the parameters.
+
+        See also
+        --------
+        `Causal, Bayesian, & non-parametric modeling of the SARS-CoV-2 viral
+        load vs. patient's age`, Guardiani, Matteo and Frank, Philipp and Kostić,
+        Andrija and Edenhofer, Gordian and Roth, Jakob and Uhlmann, Berit and
+        Enßlin, Torsten, `<https://arxiv.org/abs/2105.13483>`_
+        `<https://doi.org/10.1371/journal.pone.0275011>`_
+        """
 
-    scale = WrappedCall(scale, name=prefix + "scale")
-    ptree = scale.domain.copy()
-    cutoff = WrappedCall(cutoff, name=prefix + "cutoff")
-    ptree.update(cutoff.domain)
-    loglogslope = WrappedCall(loglogslope, name=prefix + "loglogslope")
-    ptree.update(loglogslope.domain)
+        self.grid = grid
+        self.cutoff = WrappedCall(cutoff, name=prefix + "cutoff")
+        self.loglogslope = WrappedCall(loglogslope, name=prefix + "loglogslope")
+        self.scale = (
+            WrappedCall(scale, name=prefix + "scale") if scale is not None else None
+        )
 
-    def correlate(primals: Mapping) -> jnp.ndarray:
-        scl = scale(primals)
-        ctf = cutoff(primals)
-        slp = loglogslope(primals)
+        self.kind = kind.lower()
 
-        ln_spectrum = 0.25 * slp * jnp.log1p((mode_lengths / ctf) ** 2)
+        supported_kinds = {"amplitude", "power"}
+        if self.kind not in supported_kinds:
+            raise ValueError(
+                f"Invalid kind specified {self.kind!r}, "
+                f"supported kinds: {supported_kinds}"
+            )
 
+        self.renormalize_amplitude = renormalize_amplitude
+        if self.renormalize_amplitude:
+            logger.warning("Renormalize amplidude is not yet tested!")
+
+        models = [self.scale, self.cutoff, self.loglogslope]
+        domain = reduce(operator.or_, [m.domain for m in models if m is not None])
+
+        super().__init__(domain=domain, white_init=True)
+
+    def __call__(self, primals: Mapping) -> jnp.ndarray:
+        if self.scale is None:
+            scl = 1.0
+        else:
+            scl = self.scale(primals)
+
+        ctf = self.cutoff(primals)
+        slp = self.loglogslope(primals)
+
+        ln_spectrum = (
+            0.25 * slp * jnp.log1p((self.grid.harmonic_grid.mode_lengths / ctf) ** 2)
+        )
         spectrum = jnp.exp(ln_spectrum)
 
         norm = 1.0
-        if renormalize_amplitude:
-            logger.warning("Renormalize amplidude is not yet tested!")
-            if kind.lower() == "amplitude":
-                norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 4))
-            elif kind.lower() == "power":
-                norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 2))
-            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
-        spectrum = scl * (jnp.sqrt(totvol) / norm) * spectrum
-        spectrum = spectrum.at[0].set(totvol)
-        if kind.lower() == "power":
+        if self.renormalize_amplitude:
+            if self.kind == "amplitude":
+                norm = jnp.sqrt(
+                    jnp.sum(
+                        self.grid.harmonic_grid.mode_multiplicity[1:]
+                        * spectrum[1:] ** 4
+                    )
+                )
+            elif self.kind == "power":
+                norm = jnp.sqrt(
+                    jnp.sum(
+                        self.grid.harmonic_grid.mode_multiplicity[1:]
+                        * spectrum[1:] ** 2
+                    )
+                )
+
+            norm /= jnp.sqrt(
+                self.grid.total_volume
+            )  # Due to integral in harmonic space
+        spectrum = scl * (jnp.sqrt(self.grid.total_volume) / norm) * spectrum
+        spectrum = spectrum.at[0].set(self.grid.total_volume)
+        if self.kind.lower() == "power":
             spectrum = jnp.sqrt(spectrum)
-        elif kind.lower() != "amplitude":
-            raise ValueError(f"invalid kind specified {kind!r}")
+
         return spectrum
 
-    return Model(correlate, domain=ptree, init=partial(random_like, primals=ptree))
-
-
-def non_parametric_amplitude(
-    grid,
-    fluctuations: Callable,
-    loglogavgslope: Callable,
-    flexibility: Optional[Callable] = None,
-    asperity: Optional[Callable] = None,
-    prefix: str = "",
-    kind: str = "amplitude",
-) -> Model:
-    """Constructs a function computing the amplitude of a non-parametric power
-    spectrum
-
-    See
-    :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations`
-    for more details on the parameters.
-
-    See also
-    --------
-    `Variable structures in M87* from space, time and frequency resolved
-    interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
-    and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
-    Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
-    `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
-    """
-    totvol = grid.total_volume
-    rel_log_mode_len = grid.harmonic_grid.relative_log_mode_lengths
-    mode_multiplicity = grid.harmonic_grid.mode_multiplicity
-    log_vol = grid.harmonic_grid.log_volume
 
-    fluctuations = WrappedCall(
-        fluctuations, name=prefix + "fluctuations", white_init=True
-    )
-    ptree = fluctuations.domain.copy()
-    loglogavgslope = WrappedCall(
-        loglogavgslope, name=prefix + "loglogavgslope", white_init=True
-    )
-    ptree.update(loglogavgslope.domain)
-    if flexibility is not None and (log_vol.size > 0):
-        flexibility = WrappedCall(
-            flexibility, name=prefix + "flexibility", white_init=True
+class NonParametricAmplitude(Model):
+    loglogavgslope: Callable = dataclasses.field(metadata=dict(static=False))
+    fluctuations: Callable = dataclasses.field(metadata=dict(static=False))
+    flexibility: Optional[Callable] = dataclasses.field(metadata=dict(static=False))
+    asperity: Optional[Callable] = dataclasses.field(metadata=dict(static=False))
+
+    def __init__(
+        self,
+        grid: Union[RegularCartesianGrid, RegularFourierGrid, HEALPixGrid, LMGrid],
+        fluctuations: Optional[Callable],
+        loglogavgslope: Callable,
+        flexibility: Optional[Callable] = None,
+        asperity: Optional[Callable] = None,
+        prefix: str = "",
+        kind: str = "amplitude",
+    ):
+        """Initializes a model which computes the amplitude of a non-parametric
+        power spectrum.
+
+        See
+        :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations`
+        for more details on the parameters.
+
+        See also
+        --------
+        `Variable structures in M87* from space, time and frequency resolved
+        interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
+        and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
+        Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
+        `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
+        """
+        self.grid = grid
+        log_vol = grid.harmonic_grid.log_volume
+        self.kind = kind.lower()
+
+        supported_kinds = {"amplitude", "power"}
+        if self.kind not in supported_kinds:
+            raise ValueError(
+                f"Invalid kind specified {self.kind!r}, "
+                f"supported kinds: {supported_kinds}"
+            )
+
+        self._loglogavgslope = WrappedCall(
+            loglogavgslope, name=prefix + "loglogavgslope", white_init=True
         )
-        assert log_vol is not None
-        assert rel_log_mode_len.ndim == log_vol.ndim == 1
-        if asperity is not None:
-            asperity = WrappedCall(asperity, name=prefix + "asperity", white_init=True)
-        deviations = IntegratedWienerProcess(
-            jnp.zeros((2,)),
-            flexibility,
-            log_vol,
-            name=prefix + "spectrum",
-            asperity=asperity,
+        self.fluctuations = (
+            WrappedCall(fluctuations, name=prefix + "fluctuations", white_init=True)
+            if fluctuations is not None
+            else None
         )
-        ptree.update(deviations.domain)
-    else:
-        deviations = None
+        if flexibility is not None and (log_vol.size > 0):
+            flexibility = WrappedCall(
+                flexibility, name=prefix + "flexibility", white_init=True
+            )
+            assert log_vol is not None
+            assert (
+                self.grid.harmonic_grid.relative_log_mode_lengths.ndim
+                == log_vol.ndim
+                == 1
+            )
+            if asperity is not None:
+                asperity = WrappedCall(
+                    asperity, name=prefix + "asperity", white_init=True
+                )
+            self._deviations = IntegratedWienerProcess(
+                jnp.zeros((2,)),
+                flexibility,
+                log_vol,
+                name=prefix + "spectrum",
+                asperity=asperity,
+            )
+        else:
+            self._deviations = None
 
-    def correlate(primals: Mapping) -> jnp.ndarray:
-        flu = fluctuations(primals)
-        slope = loglogavgslope(primals)
-        slope *= rel_log_mode_len
+        models = [
+            self.fluctuations,
+            self._loglogavgslope,
+            self._deviations,
+        ]
+        domain = reduce(operator.or_, [m.domain for m in models if m is not None])
+
+        super().__init__(domain=domain, white_init=True)
+
+    def __call__(self, primals: Mapping) -> jnp.ndarray:
+        mode_multiplicity = self.grid.harmonic_grid.mode_multiplicity
+        relative_log_mode_lengths = self.grid.harmonic_grid.relative_log_mode_lengths
+
+        flu = 1.0 if self.fluctuations is None else self.fluctuations(primals)
+        slope = self._loglogavgslope(primals)
+        slope *= relative_log_mode_lengths
         ln_spectrum = slope
 
-        if deviations is not None:
-            twolog = deviations(primals)
+        if self._deviations is not None:
+            twolog = self._deviations(primals)
             # Prepend zeromode
             twolog = jnp.concatenate((jnp.zeros((1,)), twolog[:, 0]))
-            ln_spectrum += _remove_slope(rel_log_mode_len, twolog)
+            ln_spectrum += _remove_slope(relative_log_mode_lengths, twolog)
 
         # Exponentiate and norm the power spectrum
         spectrum = jnp.exp(ln_spectrum)
+
         # Take the sqrt of the integral of the slope w/o fluctuations and
         # zero-mode while taking into account the multiplicity of each mode
-        if kind.lower() == "amplitude":
+        if self.kind == "amplitude":
             norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 2))
-            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
-            amplitude = flu * (jnp.sqrt(totvol) / norm) * spectrum
-        elif kind.lower() == "power":
+            norm /= jnp.sqrt(
+                self.grid.total_volume
+            )  # Due to integral in harmonic space
+            amplitude = flu * (jnp.sqrt(self.grid.total_volume) / norm) * spectrum
+        elif self.kind == "power":
             norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]))
-            norm /= jnp.sqrt(totvol)  # Due to integral in harmonic space
-            amplitude = flu * (jnp.sqrt(totvol) / norm) * jnp.sqrt(spectrum)
-        else:
-            raise ValueError(f"invalid kind specified {kind!r}")
-        amplitude = amplitude.at[0].set(totvol)
+            norm /= jnp.sqrt(
+                self.grid.total_volume
+            )  # Due to integral in harmonic space
+            amplitude = (
+                flu * (jnp.sqrt(self.grid.total_volume) / norm) * jnp.sqrt(spectrum)
+            )
+        amplitude = amplitude.at[0].set(self.grid.total_volume)
         return amplitude
 
-    return Model(correlate, domain=ptree, init=partial(random_like, primals=ptree))
-
 
 class CorrelatedFieldMaker:
     """Construction helper for hierarchical correlated field models.
@@ -579,7 +648,7 @@ class CorrelatedFieldMaker:
             te = f"invalid `asperity` specified; got '{type(asperity)}'"
             raise TypeError(te)
 
-        npa = non_parametric_amplitude(
+        npa = NonParametricAmplitude(
             grid=grid,
             fluctuations=flu,
             loglogavgslope=slp,
@@ -675,7 +744,7 @@ class CorrelatedFieldMaker:
             te = f"invalid `loglogslope` specified; got '{type(loglogslope)}'"
             raise TypeError(te)
 
-        ma = matern_amplitude(
+        ma = MaternAmplitude(
             grid=grid,
             scale=scale,
             cutoff=cutoff,