Re-MFSky
8 unresolved threads
8 unresolved threads
Edited by Philipp Frank
Merge request reports
Activity
Filter activity
assigned to @pfrank
@pfrank : Thanks for making the effort of rewriting it. Can you resolve the Merge conflicts, so that we get a cleaner diff?
added 7 commits
- 7755c6eb - re.vmodel: Fix dict handling of VModel
- 45cbf1cc - VModel: More flexible application for "incompatible" domains
- 1dad42c4 - Merge branch 'subset_vmodel' into 'NIFTy_8'
- 2c01775d - general model for multifrequency astro sky
- 98dc0f1e - re.special: Add MF model to special models
- 6a2d0669 - re.special: Fixups
- c12a8429 - compare to prev impl
Toggle commit listrequested review from @veberle
- src/re/special.py 0 → 100644
47 48 def apply(x): 49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)( 50 jnp.broadcast_to(offset(x), shp).ravel(), 51 jnp.broadcast_to(drift(x), shp).ravel(), 52 jnp.broadcast_to(deviations(x), shp + (N_freqs - 1,)).reshape( 53 (offset.target.size, -1) 54 ), 55 freqs, 56 ).reshape(tot_shape) 57 58 model = Model(apply, init=offset.init | drift.init | deviations.init) 59 model.offset = offset 60 model.drift = drift 61 model.deviations = deviations 62 model.frequencies = freqs I looked at it last week and made some notes, which I will post here tomorrow.
Edited by Vincent Eberle- src/re/special.py 0 → 100644
22 def wp_with_drift(offst, drft, dev, frequencies): 23 df = frequencies[1:] - frequencies[:-1] 24 f0 = frequencies - frequencies[0] 25 return wiener_process(dev, offst, sigma=1, dt=df) + drft * f0 26 27 28 def MFSkyModel( 29 offset: LazyModel, 30 drift: LazyModel, 31 deviations: LazyModel, 32 in_axes: Any = None, 33 out_axes=0, 34 N_freqs: int = None, 35 lim_freqs: tuple = None, 36 base=np.log, 37 _freqs=None, - src/re/special.py 0 → 100644
1 # Copyright(C) 2024 2 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause 3 # Authors: Margret Westerkamp, Vincent Eberle, Philipp Frank 4 5 import numpy as np 6 import jax.numpy as jnp 7 from typing import Any 8 from jax import vmap 9 from .gauss_markov import wiener_process 10 from .model import LazyModel, Model, VModel 11 from dataclasses import field 12 13 14 def get_freqs(lim_freqs, N_freqs, base=np.log): - src/re/special.py 0 → 100644
7 from typing import Any 8 from jax import vmap 9 from .gauss_markov import wiener_process 10 from .model import LazyModel, Model, VModel 11 from dataclasses import field 12 13 14 def get_freqs(lim_freqs, N_freqs, base=np.log): 15 freqmin = base(lim_freqs[0]) 16 freqmax = base(lim_freqs[1]) 17 if freqmax <= freqmin: 18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[") 19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin 20 21 22 def wp_with_drift(offst, drft, dev, frequencies): - src/re/special.py 0 → 100644
13 14 def get_freqs(lim_freqs, N_freqs, base=np.log): 15 freqmin = base(lim_freqs[0]) 16 freqmax = base(lim_freqs[1]) 17 if freqmax <= freqmin: 18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[") 19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin 20 21 22 def wp_with_drift(offst, drft, dev, frequencies): 23 df = frequencies[1:] - frequencies[:-1] 24 f0 = frequencies - frequencies[0] 25 return wiener_process(dev, offst, sigma=1, dt=df) + drft * f0 26 27 28 def MFSkyModel( - src/re/special.py 0 → 100644
9 from .gauss_markov import wiener_process 10 from .model import LazyModel, Model, VModel 11 from dataclasses import field 12 13 14 def get_freqs(lim_freqs, N_freqs, base=np.log): 15 freqmin = base(lim_freqs[0]) 16 freqmax = base(lim_freqs[1]) 17 if freqmax <= freqmin: 18 raise ValueError(f"Frequencies of invalid range [{freqmin}, {freqmax}[") 19 return (freqmax - freqmin) / N_freqs * np.arange(N_freqs) + freqmin 20 21 22 def wp_with_drift(offst, drft, dev, frequencies): 23 df = frequencies[1:] - frequencies[:-1] 24 f0 = frequencies - frequencies[0] - Comment on lines +23 to +24
IMHO hardcoding this into the model removes the capability of shifting the y-axis, which was possible in the "old" model
My suggestions
- we could add this additional DOF (reference_frequency and frequencies)
- or the used does is in advance and we only pass f0 and df.
open for other suggestions.
- src/re/special.py 0 → 100644
38 ): 39 shp = np.broadcast_shapes( 40 offset.target.shape, drift.target.shape, deviations.target.shape 41 ) 42 freqs = get_freqs(lim_freqs, N_freqs, base=base) if _freqs is None else _freqs 43 N_freqs = freqs.size 44 if in_axes is not None: 45 deviations = VModel(deviations, N_freqs - 1, in_axes=in_axes, out_axes=-1) 46 tot_shape = shp[:out_axes] + (N_freqs,) + shp[out_axes:] 47 48 def apply(x): 49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)( 50 jnp.broadcast_to(offset(x), shp).ravel(), 51 jnp.broadcast_to(drift(x), shp).ravel(), 52 jnp.broadcast_to(deviations(x), shp + (N_freqs - 1,)).reshape( 53 (offset.target.size, -1) - src/re/special.py 0 → 100644
35 lim_freqs: tuple = None, 36 base=np.log, 37 _freqs=None, 38 ): 39 shp = np.broadcast_shapes( 40 offset.target.shape, drift.target.shape, deviations.target.shape 41 ) 42 freqs = get_freqs(lim_freqs, N_freqs, base=base) if _freqs is None else _freqs 43 N_freqs = freqs.size 44 if in_axes is not None: 45 deviations = VModel(deviations, N_freqs - 1, in_axes=in_axes, out_axes=-1) 46 tot_shape = shp[:out_axes] + (N_freqs,) + shp[out_axes:] 47 48 def apply(x): 49 return vmap(wp_with_drift, in_axes=(0,) * 3 + (None,), out_axes=~out_axes)( 50 jnp.broadcast_to(offset(x), shp).ravel(),