Skip to content
Snippets Groups Projects

Re-MFSky

Open Philipp Frank requested to merge mf_sky into general_sky
8 unresolved threads

Some proposed reworkings for an MF Sky. @veberle, @gedenhof, @jroth, @wmarg: Happy for any feedback!

Note to self: The changes here should not be merged here, but to main directly.

Edited by Philipp Frank

Merge request reports

Ready to merge by members who can write to the target branch.
  • 0 commits and 1 merge commit will be added to .
  • Source branch will not be deleted.

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
src/re/special.py 0 → 100644
47
48 def apply(x):
49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)(
50 jnp.broadcast_to(offset(x), shp).ravel(),
51 jnp.broadcast_to(drift(x), shp).ravel(),
52 jnp.broadcast_to(deviations(x), shp + (N_freqs - 1,)).reshape(
53 (offset.target.size, -1)
54 ),
55 freqs,
56 ).reshape(tot_shape)
57
58 model = Model(apply, init=offset.init | drift.init | deviations.init)
59 model.offset = offset
60 model.drift = drift
61 model.deviations = deviations
62 model.frequencies = freqs
  • Nice! Many thanks to everybody who contributed to this! I would love to see this in NIFTy soon :)

  • I looked at it last week and made some notes, which I will post here tomorrow.

    Edited by Vincent Eberle
  • src/re/special.py 0 → 100644
    22 def wp_with_drift(offst, drft, dev, frequencies):
    23 df = frequencies[1:] - frequencies[:-1]
    24 f0 = frequencies - frequencies[0]
    25 return wiener_process(dev, offst, sigma=1, dt=df) + drft * f0
    26
    27
    28 def MFSkyModel(
    29 offset: LazyModel,
    30 drift: LazyModel,
    31 deviations: LazyModel,
    32 in_axes: Any = None,
    33 out_axes=0,
    34 N_freqs: int = None,
    35 lim_freqs: tuple = None,
    36 base=np.log,
    37 _freqs=None,
  • src/re/special.py 0 → 100644
    1 # Copyright(C) 2024
    2 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
    3 # Authors: Margret Westerkamp, Vincent Eberle, Philipp Frank
    4
    5 import numpy as np
    6 import jax.numpy as jnp
    7 from typing import Any
    8 from jax import vmap
    9 from .gauss_markov import wiener_process
    10 from .model import LazyModel, Model, VModel
    11 from dataclasses import field
    12
    13
    14 def get_freqs(lim_freqs, N_freqs, base=np.log):
  • src/re/special.py 0 → 100644
    7 from typing import Any
    8 from jax import vmap
    9 from .gauss_markov import wiener_process
    10 from .model import LazyModel, Model, VModel
    11 from dataclasses import field
    12
    13
    14 def get_freqs(lim_freqs, N_freqs, base=np.log):
    15 freqmin = base(lim_freqs[0])
    16 freqmax = base(lim_freqs[1])
    17 if freqmax <= freqmin:
    18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[")
    19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin
    20
    21
    22 def wp_with_drift(offst, drft, dev, frequencies):
  • src/re/special.py 0 → 100644
    13
    14 def get_freqs(lim_freqs, N_freqs, base=np.log):
    15 freqmin = base(lim_freqs[0])
    16 freqmax = base(lim_freqs[1])
    17 if freqmax <= freqmin:
    18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[")
    19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin
    20
    21
    22 def wp_with_drift(offst, drft, dev, frequencies):
    23 df = frequencies[1:] - frequencies[:-1]
    24 f0 = frequencies - frequencies[0]
    25 return wiener_process(dev, offst, sigma=1, dt=df) + drft * f0
    26
    27
    28 def MFSkyModel(
  • src/re/special.py 0 → 100644
    9 from .gauss_markov import wiener_process
    10 from .model import LazyModel, Model, VModel
    11 from dataclasses import field
    12
    13
    14 def get_freqs(lim_freqs, N_freqs, base=np.log):
    15 freqmin = base(lim_freqs[0])
    16 freqmax = base(lim_freqs[1])
    17 if freqmax <= freqmin:
    18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[")
    19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin
    20
    21
    22 def wp_with_drift(offst, drft, dev, frequencies):
    23 df = frequencies[1:] - frequencies[:-1]
    24 f0 = frequencies - frequencies[0]
    • Comment on lines +23 to +24

      IMHO hardcoding this into the model removes the capability of shifting the y-axis, which was possible in the "old" model

      My suggestions

      1. we could add this additional DOF (reference_frequency and frequencies)
      2. or the used does is in advance and we only pass f0 and df.

      open for other suggestions.

    • Please register or sign in to reply
  • src/re/special.py 0 → 100644
    38 ):
    39 shp = np.broadcast_shapes(
    40 offset.target.shape, drift.target.shape, deviations.target.shape
    41 )
    42 freqs = get_freqs(lim_freqs, N_freqs, base=base) if _freqs is None else _freqs
    43 N_freqs = freqs.size
    44 if in_axes is not None:
    45 deviations = VModel(deviations, N_freqs - 1, in_axes=in_axes, out_axes=-1)
    46 tot_shape = shp[:out_axes] + (N_freqs,) + shp[out_axes:]
    47
    48 def apply(x):
    49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)(
    50 jnp.broadcast_to(offset(x), shp).ravel(),
    51 jnp.broadcast_to(drift(x), shp).ravel(),
    52 jnp.broadcast_to(deviations(x), shp + (N_freqs - 1,)).reshape(
    53 (offset.target.size, -1)
    • Comment on lines +52 to +53

      what do I need to put in for the deviations, if I want the sky model, with spatially correlated deviations, you proposed? probably a CF? Could you add this case to the demo?

    • Please register or sign in to reply
  • src/re/special.py 0 → 100644
    35 lim_freqs: tuple = None,
    36 base=np.log,
    37 _freqs=None,
    38 ):
    39 shp = np.broadcast_shapes(
    40 offset.target.shape, drift.target.shape, deviations.target.shape
    41 )
    42 freqs = get_freqs(lim_freqs, N_freqs, base=base) if _freqs is None else _freqs
    43 N_freqs = freqs.size
    44 if in_axes is not None:
    45 deviations = VModel(deviations, N_freqs - 1, in_axes=in_axes, out_axes=-1)
    46 tot_shape = shp[:out_axes] + (N_freqs,) + shp[out_axes:]
    47
    48 def apply(x):
    49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)(
    50 jnp.broadcast_to(offset(x), shp).ravel(),
    • Comment on lines +48 to +50

      I really like the slim design, however, afaik this is more stiff than before? e.g. to remove the powerlaw I need to use:

      def drift(x):
          return 0.
      ``
    • Please register or sign in to reply
    Please register or sign in to reply
    Loading