diff --git a/src/re/__init__.py b/src/re/__init__.py index 1c1b9b88721c737d667a0f20849d34597aa03692..d6e0efce1a5520e8eec0d4b4598358f5ba558af9 100644 --- a/src/re/__init__.py +++ b/src/re/__init__.py @@ -3,7 +3,7 @@ from .. import config from . import structured_kernel_interpolation from .conjugate_gradient import cg, static_cg -from .correlated_field import CorrelatedFieldMaker, non_parametric_amplitude +from .correlated_field import CorrelatedFieldMaker from .custom_map import lmap, smap from .evi import ( Samples, diff --git a/src/re/correlated_field.py b/src/re/correlated_field.py index bdcb620ec06c8dd45dcf4c42f4f18b4a581e3b99..52028afb3cdea9b64096d4a394242055bf75dd60 100644 --- a/src/re/correlated_field.py +++ b/src/re/correlated_field.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -# Authors: Gordian Edenhofer, Philipp Frank +# Authors: Gordian Edenhofer, Philipp Frank, +# Matteo Guardiani, Julian Rüstig import operator +import dataclasses from collections import namedtuple from collections.abc import Mapping -from functools import partial +from functools import partial, reduce from typing import Callable, Optional, Tuple, Union import numpy as np -from jax import numpy as jnp -from jax import vmap +from jax import numpy as jnp, vmap from ..config import _config from .gauss_markov import IntegratedWienerProcess @@ -267,7 +268,7 @@ def make_grid( msg = "`shape` must be length one. Its the nside of the spherical grid." raise ValueError(msg) nside = shape[0] - (m_length_idx, m_length, m_count), (lmax, mmax, size) = ( + ((m_length_idx, m_length, m_count), (lmax, mmax, size)) = ( get_spherical_mode_distributor(nside) ) um, log_vol = _log_modes(m_length) @@ -298,157 +299,225 @@ def _remove_slope(rel_log_mode_dist, x): return x - x[-1] * sc -def matern_amplitude( - grid, - scale: Callable, - cutoff: Callable, - loglogslope: Callable, - renormalize_amplitude: bool, - prefix: str = "", - kind: str = "amplitude", -) -> Model: - """Constructs a function computing the amplitude of a Matérn-kernel - power spectrum. - - See - :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations - _matern` - for more details on the parameters. - - See also - -------- - `Causal, Bayesian, & non-parametric modeling of the SARS-CoV-2 viral - load vs. patient's age`, Guardiani, Matteo and Frank, Philipp and Kostić, - Andrija and Edenhofer, Gordian and Roth, Jakob and Uhlmann, Berit and - Enßlin, Torsten, `<https://arxiv.org/abs/2105.13483>`_ - `<https://doi.org/10.1371/journal.pone.0275011>`_ - """ - totvol = grid.total_volume - mode_lengths = grid.harmonic_grid.mode_lengths - mode_multiplicity = grid.harmonic_grid.mode_multiplicity +class MaternAmplitude(Model): + cutoff: Callable = dataclasses.field(metadata=dict(static=False)) + loglogslope: Callable = dataclasses.field(metadata=dict(static=False)) + scale: Callable = dataclasses.field(metadata=dict(static=False)) + + def __init__( + self, + grid: Union[RegularCartesianGrid, RegularFourierGrid, HEALPixGrid, LMGrid], + scale: Optional[Callable], + cutoff: Callable, + loglogslope: Callable, + renormalize_amplitude: bool, + prefix: str = "", + kind: str = "amplitude", + ): + """Initializes a model that computes the amplitude of a Matérn-kernel + power spectrum. + + See + :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations + _matern` + for more details on the parameters. + + See also + -------- + `Causal, Bayesian, & non-parametric modeling of the SARS-CoV-2 viral + load vs. patient's age`, Guardiani, Matteo and Frank, Philipp and Kostić, + Andrija and Edenhofer, Gordian and Roth, Jakob and Uhlmann, Berit and + Enßlin, Torsten, `<https://arxiv.org/abs/2105.13483>`_ + `<https://doi.org/10.1371/journal.pone.0275011>`_ + """ - scale = WrappedCall(scale, name=prefix + "scale") - ptree = scale.domain.copy() - cutoff = WrappedCall(cutoff, name=prefix + "cutoff") - ptree.update(cutoff.domain) - loglogslope = WrappedCall(loglogslope, name=prefix + "loglogslope") - ptree.update(loglogslope.domain) + self.grid = grid + self.cutoff = WrappedCall(cutoff, name=prefix + "cutoff") + self.loglogslope = WrappedCall(loglogslope, name=prefix + "loglogslope") + self.scale = ( + WrappedCall(scale, name=prefix + "scale") if scale is not None else None + ) - def correlate(primals: Mapping) -> jnp.ndarray: - scl = scale(primals) - ctf = cutoff(primals) - slp = loglogslope(primals) + self.kind = kind.lower() - ln_spectrum = 0.25 * slp * jnp.log1p((mode_lengths / ctf) ** 2) + supported_kinds = {"amplitude", "power"} + if self.kind not in supported_kinds: + raise ValueError( + f"Invalid kind specified {self.kind!r}, " + f"supported kinds: {supported_kinds}" + ) + self.renormalize_amplitude = renormalize_amplitude + if self.renormalize_amplitude: + logger.warning("Renormalize amplidude is not yet tested!") + + models = [self.scale, self.cutoff, self.loglogslope] + domain = reduce(operator.or_, [m.domain for m in models if m is not None]) + + super().__init__(domain=domain, white_init=True) + + def __call__(self, primals: Mapping) -> jnp.ndarray: + if self.scale is None: + scl = 1.0 + else: + scl = self.scale(primals) + + ctf = self.cutoff(primals) + slp = self.loglogslope(primals) + + ln_spectrum = ( + 0.25 * slp * jnp.log1p((self.grid.harmonic_grid.mode_lengths / ctf) ** 2) + ) spectrum = jnp.exp(ln_spectrum) norm = 1.0 - if renormalize_amplitude: - logger.warning("Renormalize amplidude is not yet tested!") - if kind.lower() == "amplitude": - norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 4)) - elif kind.lower() == "power": - norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 2)) - norm /= jnp.sqrt(totvol) # Due to integral in harmonic space - spectrum = scl * (jnp.sqrt(totvol) / norm) * spectrum - spectrum = spectrum.at[0].set(totvol) - if kind.lower() == "power": + if self.renormalize_amplitude: + if self.kind == "amplitude": + norm = jnp.sqrt( + jnp.sum( + self.grid.harmonic_grid.mode_multiplicity[1:] + * spectrum[1:] ** 4 + ) + ) + elif self.kind == "power": + norm = jnp.sqrt( + jnp.sum( + self.grid.harmonic_grid.mode_multiplicity[1:] + * spectrum[1:] ** 2 + ) + ) + + norm /= jnp.sqrt( + self.grid.total_volume + ) # Due to integral in harmonic space + spectrum = scl * (jnp.sqrt(self.grid.total_volume) / norm) * spectrum + spectrum = spectrum.at[0].set(self.grid.total_volume) + if self.kind.lower() == "power": spectrum = jnp.sqrt(spectrum) - elif kind.lower() != "amplitude": - raise ValueError(f"invalid kind specified {kind!r}") + return spectrum - return Model(correlate, domain=ptree, init=partial(random_like, primals=ptree)) - - -def non_parametric_amplitude( - grid, - fluctuations: Callable, - loglogavgslope: Callable, - flexibility: Optional[Callable] = None, - asperity: Optional[Callable] = None, - prefix: str = "", - kind: str = "amplitude", -) -> Model: - """Constructs a function computing the amplitude of a non-parametric power - spectrum - - See - :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations` - for more details on the parameters. - - See also - -------- - `Variable structures in M87* from space, time and frequency resolved - interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp - and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and - Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_ - `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_ - """ - totvol = grid.total_volume - rel_log_mode_len = grid.harmonic_grid.relative_log_mode_lengths - mode_multiplicity = grid.harmonic_grid.mode_multiplicity - log_vol = grid.harmonic_grid.log_volume - fluctuations = WrappedCall( - fluctuations, name=prefix + "fluctuations", white_init=True - ) - ptree = fluctuations.domain.copy() - loglogavgslope = WrappedCall( - loglogavgslope, name=prefix + "loglogavgslope", white_init=True - ) - ptree.update(loglogavgslope.domain) - if flexibility is not None and (log_vol.size > 0): - flexibility = WrappedCall( - flexibility, name=prefix + "flexibility", white_init=True +class NonParametricAmplitude(Model): + loglogavgslope: Callable = dataclasses.field(metadata=dict(static=False)) + fluctuations: Callable = dataclasses.field(metadata=dict(static=False)) + flexibility: Optional[Callable] = dataclasses.field(metadata=dict(static=False)) + asperity: Optional[Callable] = dataclasses.field(metadata=dict(static=False)) + + def __init__( + self, + grid: Union[RegularCartesianGrid, RegularFourierGrid, HEALPixGrid, LMGrid], + fluctuations: Optional[Callable], + loglogavgslope: Callable, + flexibility: Optional[Callable] = None, + asperity: Optional[Callable] = None, + prefix: str = "", + kind: str = "amplitude", + ): + """Initializes a model which computes the amplitude of a non-parametric + power spectrum. + + See + :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations` + for more details on the parameters. + + See also + -------- + `Variable structures in M87* from space, time and frequency resolved + interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp + and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and + Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_ + `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_ + """ + self.grid = grid + log_vol = grid.harmonic_grid.log_volume + self.kind = kind.lower() + + supported_kinds = {"amplitude", "power"} + if self.kind not in supported_kinds: + raise ValueError( + f"Invalid kind specified {self.kind!r}, " + f"supported kinds: {supported_kinds}" + ) + + self._loglogavgslope = WrappedCall( + loglogavgslope, name=prefix + "loglogavgslope", white_init=True ) - assert log_vol is not None - assert rel_log_mode_len.ndim == log_vol.ndim == 1 - if asperity is not None: - asperity = WrappedCall(asperity, name=prefix + "asperity", white_init=True) - deviations = IntegratedWienerProcess( - jnp.zeros((2,)), - flexibility, - log_vol, - name=prefix + "spectrum", - asperity=asperity, + self.fluctuations = ( + WrappedCall(fluctuations, name=prefix + "fluctuations", white_init=True) + if fluctuations is not None + else None ) - ptree.update(deviations.domain) - else: - deviations = None + if flexibility is not None and (log_vol.size > 0): + flexibility = WrappedCall( + flexibility, name=prefix + "flexibility", white_init=True + ) + assert log_vol is not None + assert ( + self.grid.harmonic_grid.relative_log_mode_lengths.ndim + == log_vol.ndim + == 1 + ) + if asperity is not None: + asperity = WrappedCall( + asperity, name=prefix + "asperity", white_init=True + ) + self._deviations = IntegratedWienerProcess( + jnp.zeros((2,)), + flexibility, + log_vol, + name=prefix + "spectrum", + asperity=asperity, + ) + else: + self._deviations = None - def correlate(primals: Mapping) -> jnp.ndarray: - flu = fluctuations(primals) - slope = loglogavgslope(primals) - slope *= rel_log_mode_len + models = [ + self.fluctuations, + self._loglogavgslope, + self._deviations, + ] + domain = reduce(operator.or_, [m.domain for m in models if m is not None]) + + super().__init__(domain=domain, white_init=True) + + def __call__(self, primals: Mapping) -> jnp.ndarray: + mode_multiplicity = self.grid.harmonic_grid.mode_multiplicity + relative_log_mode_lengths = self.grid.harmonic_grid.relative_log_mode_lengths + + flu = 1.0 if self.fluctuations is None else self.fluctuations(primals) + slope = self._loglogavgslope(primals) + slope *= relative_log_mode_lengths ln_spectrum = slope - if deviations is not None: - twolog = deviations(primals) + if self._deviations is not None: + twolog = self._deviations(primals) # Prepend zeromode twolog = jnp.concatenate((jnp.zeros((1,)), twolog[:, 0])) - ln_spectrum += _remove_slope(rel_log_mode_len, twolog) + ln_spectrum += _remove_slope(relative_log_mode_lengths, twolog) # Exponentiate and norm the power spectrum spectrum = jnp.exp(ln_spectrum) + # Take the sqrt of the integral of the slope w/o fluctuations and # zero-mode while taking into account the multiplicity of each mode - if kind.lower() == "amplitude": + if self.kind == "amplitude": norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:] ** 2)) - norm /= jnp.sqrt(totvol) # Due to integral in harmonic space - amplitude = flu * (jnp.sqrt(totvol) / norm) * spectrum - elif kind.lower() == "power": + norm /= jnp.sqrt( + self.grid.total_volume + ) # Due to integral in harmonic space + amplitude = flu * (jnp.sqrt(self.grid.total_volume) / norm) * spectrum + elif self.kind == "power": norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:])) - norm /= jnp.sqrt(totvol) # Due to integral in harmonic space - amplitude = flu * (jnp.sqrt(totvol) / norm) * jnp.sqrt(spectrum) - else: - raise ValueError(f"invalid kind specified {kind!r}") - amplitude = amplitude.at[0].set(totvol) + norm /= jnp.sqrt( + self.grid.total_volume + ) # Due to integral in harmonic space + amplitude = ( + flu * (jnp.sqrt(self.grid.total_volume) / norm) * jnp.sqrt(spectrum) + ) + amplitude = amplitude.at[0].set(self.grid.total_volume) return amplitude - return Model(correlate, domain=ptree, init=partial(random_like, primals=ptree)) - class CorrelatedFieldMaker: """Construction helper for hierarchical correlated field models. @@ -579,7 +648,7 @@ class CorrelatedFieldMaker: te = f"invalid `asperity` specified; got '{type(asperity)}'" raise TypeError(te) - npa = non_parametric_amplitude( + npa = NonParametricAmplitude( grid=grid, fluctuations=flu, loglogavgslope=slp, @@ -675,7 +744,7 @@ class CorrelatedFieldMaker: te = f"invalid `loglogslope` specified; got '{type(loglogslope)}'" raise TypeError(te) - ma = matern_amplitude( + ma = MaternAmplitude( grid=grid, scale=scale, cutoff=cutoff,