diff --git a/src/re/__init__.py b/src/re/__init__.py index c6e4495617f6f4c003fa533577b6c1891d479414..1c1b9b88721c737d667a0f20849d34597aa03692 100644 --- a/src/re/__init__.py +++ b/src/re/__init__.py @@ -39,8 +39,10 @@ from .multi_grid import ( Grid, HEALPixGrid, HPLogRGrid, + HPBrokenLogRGrid, ICRField, - LogarithmicGrid, + LogGrid, + BrokenLogGrid, MGrid, SimpleOpenGrid, ) diff --git a/src/re/multi_grid/__init__.py b/src/re/multi_grid/__init__.py index ecbcc7ff3064281b118de16b41767b9a9ef8d787..07d395a5cfa47d5fc20426b01bd15f2a64f5161e 100644 --- a/src/re/multi_grid/__init__.py +++ b/src/re/multi_grid/__init__.py @@ -2,4 +2,4 @@ from .correlated_field import ICRField from .grid import Grid, MGrid -from .grid_impl import HEALPixGrid, HPLogRGrid, LogarithmicGrid, SimpleOpenGrid +from .grid_impl import HEALPixGrid, HPLogRGrid, LogGrid, SimpleOpenGrid, BrokenLogGrid, HPBrokenLogRGrid diff --git a/src/re/multi_grid/grid.py b/src/re/multi_grid/grid.py index 0f559c05fb8b134c8d73d0d8f683e5ec959a61c7..9f1a20ae2c4ec90790b83c3f9415a985a550a1a9 100644 --- a/src/re/multi_grid/grid.py +++ b/src/re/multi_grid/grid.py @@ -112,8 +112,11 @@ class GridAtLevel: def coord2index(self, coord, dtype=np.uint64): slc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1) - # TODO type casting - return (coord * self.shape[slc] - 0.5).astype(dtype) + index = (coord * self.shape[slc] - 0.5) + if np.issubdtype(dtype, np.integer): + return np.rint(index).astype(dtype) + else: + raise ValueError(f"non-integer index dtype: {dtype}") def index2volume(self, index): return np.array(1.0 / self.size)[(np.newaxis,) * index.ndim] @@ -234,10 +237,11 @@ class OpenGridAtLevel(GridAtLevel): def coord2index(self, coord, dtype=np.uint64): slc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1) shp = self.shape + 2 * self.shifts - # TODO type index = coord * shp[slc] - self.shifts[slc] - 0.5 - # assert jnp.all(index >= 0) - return index.astype(dtype) + if np.issubdtype(dtype, np.integer): + return np.rint(index).astype(dtype) + else: + raise ValueError(f"non-integer index dtype: {dtype}") def index2volume(self, index): sz = np.prod(self.shape + 2 * self.shifts) diff --git a/src/re/multi_grid/grid_impl.py b/src/re/multi_grid/grid_impl.py index 9931bece9643ba8c0db86e631702c4629b3f536f..f307bca24191759035790b7eb50b7297ae1a1e5e 100644 --- a/src/re/multi_grid/grid_impl.py +++ b/src/re/multi_grid/grid_impl.py @@ -21,7 +21,6 @@ from .grid import Grid, GridAtLevel, MGrid, MGridAtLevel, OpenGrid, OpenGridAtLe class HEALPixGridAtLevel(GridAtLevel): nside: int nest: bool - fill_strategy: str def __init__( self, @@ -113,24 +112,56 @@ class HEALPixGrid(Grid): ): """HEALPix pixelization grid. - The grid can be defined either via the final nside, the initial nside - (nside0), or the initial shape (shape0). + The grid can be defined either via depth and exactly one of (nside0, nside, shape0) + or via nside and one of (nside0, shape0). """ self.nest = nest + + assert shape0 is None or isinstance(shape0, (int, tuple, np.ndarray)) if shape0 is not None: assert nside0 is None - assert isinstance(shape0, int) or np.ndim(shape0) == 0 - shape0 = shape0[0] if np.ndim(shape0) > 0 else shape0 + shape0 = np.asarray(shape0).ravel() + assert shape0.size == 1, ( + "shape0 must be a scalar or a single-element array/tuple" + ) + (shape0,) = shape0 + assert isinstance(shape0, int) + # Check whether the shape is a valid HEALPix shape + assert shape0 > 0 and shape0 % 12 == 0 nside0 = (shape0 / 12) ** 0.5 - assert int(nside0) == nside0 - nside0 = int(nside0) - if nside is not None: - assert nside0 is None - assert depth is not None - nside0 = nside / 2**depth - assert int(nside0) == nside0 + assert np.isclose(nside0, round(nside0), atol=1.0e-10) + nside0 = round(nside0) + + assert nside is None or ( + isinstance(nside, int) + and nside > 0 + and (nside & (nside - 1)) == 0 # power of 2 + ) + assert nside0 is None or ( + isinstance(nside0, int) + and nside0 > 0 + and (nside0 & (nside0 - 1)) == 0 # power of 2 + ) + assert depth is None or (isinstance(depth, int) and depth >= 0) + + if depth is not None: + if (nside0 is None) == (nside is None): + raise ValueError( + "Ambiguous initialisation of HEALPixGrid. If depth is given, please supply exactly one of (nside0, nside, shape0)" + ) + if nside is not None: + nside0 = nside // 2**depth + else: + if (nside is None) or (nside0 is None): + raise ValueError( + "Ambiguous initialisation of HEALPixGrid. If depth is not given, please supply nside and exactly one of (nside0, shape0)" + ) + assert nside0 <= nside + depth = np.log2(nside / nside0) + assert np.isclose(depth, round(depth), atol=1.0e-10) + depth = round(depth) + self.nside0 = nside0 - assert self.nside0 > 0 if splits is None: splits = (4,) * depth super().__init__( @@ -181,7 +212,7 @@ class SimpleOpenGridAtLevel(OpenGridAtLevel): def coord2index(self, coord, dtype=np.uint64): bc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1) coord = coord / ((self.shape + 2 * self.shifts) * self.distances)[bc] - return super().coord2index(self, coord, dtype=dtype) + return super().coord2index(coord, dtype=dtype) def index2volume(self, index): vol = super().index2volume(index) @@ -267,7 +298,7 @@ def SimpleOpenGrid( ) -class LogarithmicGridAtLevel(SimpleOpenGridAtLevel): +class LogGridAtLevel(SimpleOpenGridAtLevel): def __init__(self, *args, coord_offset, coord_scale, **kwargs): # NOTE, technically `coord_offset` and `coord_scale` are redundant with # `shifts` and `distances`, however, for ease of use, we first let them @@ -290,15 +321,15 @@ class LogarithmicGridAtLevel(SimpleOpenGridAtLevel): def coord2index(self, coord, dtype=np.uint64): coord = (jnp.log(coord) - self.coord_offset) / self.coord_scale - return super().coord2index(self, coord, dtype=dtype) + return super().coord2index(coord, dtype=dtype) def index2volume(self, index): a = (slice(None),) + (np.newaxis,) * index.ndim - coords = super().index2coord(index + jnp.array([-0.5, 0.5])[a]) - return jnp.prod(coords[1] - coords[0], axis=0) + coords = self.index2coord(index + jnp.array([-0.5, 0.5])[a]) + return jnp.prod(coords[1] - coords[0], axis=0, keepdims=True) -def LogarithmicGrid( +def LogGrid( *, r_min: float, r_max: float, @@ -310,19 +341,19 @@ def LogarithmicGrid( """ if distances is not None: raise ValueError("`distances` are incompatible with a logarithmic grid") - if r_min < 0.0 or r_max < r_min: + if r_min <= 0.0 or r_max <= r_min: raise ValueError(f"invalid r_min {r_min!r} or r_max {r_max!r}") coord_offset = np.log(r_min) coord_scale = np.log(r_max) - coord_offset return SimpleOpenGrid( **kwargs, atLevel=partial( - LogarithmicGridAtLevel, coord_offset=coord_offset, coord_scale=coord_scale + LogGridAtLevel, coord_offset=coord_offset, coord_scale=coord_scale ), ) -class HPLogRGridAtLevel(MGridAtLevel): +class HPRadialGridAtLevel(MGridAtLevel): def index2coord(self, index, **kwargs): coords = super().index2coord(index, **kwargs) return coords[:3] * coords[3] @@ -333,6 +364,13 @@ class HPLogRGridAtLevel(MGridAtLevel): coord = jnp.concatenate((coord / r, r), axis=0) return super().coord2index(coord, **kwargs) + def index2volume(self, index): + grid_hp, grid_r = self.grids + r_upper = grid_r.index2coord(index[1:2] + 0.5) + r_lower = grid_r.index2coord(index[1:2] - 0.5) + A_unity = grid_hp.index2volume(index[0:1]) + return A_unity * (r_upper**3 - r_lower**3) / 3 + def HPLogRGrid( min_shape: Optional[Tuple[int, int]] = None, @@ -343,11 +381,193 @@ def HPLogRGrid( r_max, r_window_size=3, nside0=16, - atLevel=HPLogRGridAtLevel, + atLevel=HPRadialGridAtLevel, ) -> MGrid: """Meshgrid of a HEALPix grid and a logarithmic grid. - See `HEALPixGrid` and `LogarithmicGrid`.""" + See `HEALPixGrid` and `LogGrid`.""" + if r_min_shape is None and nside is None: + hp_size, r_min_shape = min_shape + nside = (hp_size / 12) ** 0.5 + depth = np.log2(nside / nside0) + assert depth == int(depth) + depth = int(depth) + grid_hp = HEALPixGrid(nside0=nside0, depth=depth) + grid_r = LogGrid( + min_shape=r_min_shape, + r_min=r_min, + r_max=r_max, + window_size=r_window_size, + depth=depth, + ) + return MGrid(grid_hp, grid_r, atLevel=atLevel) + + +class BrokenLogGridAtLevel(SimpleOpenGridAtLevel): + def __init__( + self, + *args, + alpha, + beta, + gamma, + delta, + epsilon, + r_min, + r_linthresh, + r_max, + rg_min, + rg_linthresh, + rg_max, + **kwargs, + ): + self._alpha = alpha + self._beta = beta + self._gamma = gamma + self._delta = delta + self._epsilon = epsilon + self._r_min = r_min + self._r_linthresh = r_linthresh + self._r_max = r_max + self._rg_min = rg_min + self._rg_linthresh = rg_linthresh + self._rg_max = rg_max + super().__init__(*args, **kwargs) + + @property + def r_min(self): + return self.index2coord(np.array([-0.5])) + + @property + def r_max(self): + return self.index2coord(np.array([self.shape[0] - 0.5])) + + def index2coord(self, index): + # map to in-between 0 and 1 + coord = super().index2coord(index) + # map to in-between r_min and r_max + condlist = [ + coord < self._rg_min, + (self._rg_min <= coord) & (coord < self._rg_linthresh), + (self._rg_linthresh <= coord) & (coord < self._rg_max), + self._rg_max <= coord, + ] + funclist = [ + lambda rg: self._gamma / (rg - self._delta), + lambda rg: self._r_min + self._alpha * (rg - self._rg_min), + lambda rg: self._r_linthresh + * jnp.exp(self._beta * (rg - self._rg_linthresh)), + lambda rg: self._r_max + self._epsilon * (rg - self._rg_max), + ] + return jnp.piecewise(coord, condlist, funclist) + + def coord2index(self, coord, dtype=np.uint64): + # map to in-between 0 and 1 + condlist = [ + coord < self._r_min, + (self._r_min <= coord) & (coord < self._r_linthresh), + (self._r_linthresh <= coord) & (coord < self._r_max), + self._r_max <= coord, + ] + funclist = [ + lambda r: self._delta + self._gamma / r, + lambda r: self._rg_min + (r - self._r_min) / self._alpha, + lambda r: self._rg_linthresh + jnp.log(r / self._r_linthresh) / self._beta, + lambda r: self._rg_max + (r - self._r_max) / self._epsilon, + ] + coord = jnp.piecewise(coord, condlist, funclist) + # transform to index + return super().coord2index(coord, dtype=dtype) + + def index2volume(self, index): + a = (slice(None),) + (np.newaxis,) * index.ndim + coords = self.index2coord(index + jnp.array([-0.5, 0.5])[a]) + return jnp.prod(coords[1] - coords[0], axis=0, keepdims=True) + + +def BrokenLogGrid( + *, + r_min: float, + r_linthresh: float, + r_max: float, + distances=None, + **kwargs, +) -> OpenGrid: + """Create a broken logarithmic grid on top of `SimpleOpenGrid` spanning from + `r_min` to `r_max` at the final depth. + The grid is parametrised by three radii: r_min, r_linthresh, and r_max. + Between r_min and r_linthresh pixels are spaced linearly (r). + Between r_linthresh and r_max pixels are spaced logarithmically (exp(r)). + + For available parameters see the `SimpleOpenGrid` docstring in addition the ones below. + + Parameters + ---------- + r_min: + Minimum coordinate value. + r_linthresh: + Coordinate value at which the grid switches from linear to logarithmic spacing. + r_max: + Maximum coordinate value. + + Notes + ----- + For values below rmin, the (padded) pixels are spaced antilinearly (1/r). + Above rmax they are spaced linearly (r). + """ + if distances is not None: + raise ValueError("`distances` are incompatible with a logarithmic grid") + if r_min <= 0.0 or r_max <= r_min: + raise ValueError(f"invalid r_min {r_min!r} or r_max {r_max!r}") + if r_linthresh < r_min or r_max <= r_linthresh: + raise ValueError(f"invalid r_0 {r_linthresh!r}") + + # This parametrisation is technically capable of handling a transformation + # from arbitrary rg_min and rg_max, but in accordance to the LogGrid, + # we can fix them to 0 and 1 and use the parent class for mapping them there. + rg_min = 0.0 + rg_max = 1.0 + m = (1.0 - r_min / r_linthresh) / (jnp.log(r_max / r_linthresh)) + rg_linthresh = rg_min / (1 + m) + rg_max * m / (1 + m) + alpha = r_linthresh / (rg_max - rg_linthresh) * jnp.log(r_max / r_linthresh) + beta = alpha / r_linthresh + gamma = -(r_min**2) / alpha + delta = rg_min + r_min / alpha + epsilon = r_linthresh * beta * jnp.exp(beta * (rg_max - rg_linthresh)) + + return SimpleOpenGrid( + **kwargs, + atLevel=partial( + BrokenLogGridAtLevel, + alpha=alpha, + beta=beta, + gamma=gamma, + delta=delta, + epsilon=epsilon, + r_min=r_min, + r_linthresh=r_linthresh, + r_max=r_max, + rg_min=rg_min, + rg_linthresh=rg_linthresh, + rg_max=rg_max, + ), + ) + + +def HPBrokenLogRGrid( + min_shape: Optional[Tuple[int, int]] = None, + *, + nside: Optional[int] = None, + r_min_shape: Optional[int] = None, + r_min, + r_linthresh, + r_max, + r_window_size=3, + nside0=16, + atLevel=HPRadialGridAtLevel, +) -> MGrid: + """Meshgrid of a HEALPix grid and a broken logarithmic grid. + + See `HEALPixGrid` and `BrokenLogGrid`.""" if r_min_shape is None and nside is None: hp_size, r_min_shape = min_shape nside = (hp_size / 12) ** 0.5 @@ -355,9 +575,10 @@ def HPLogRGrid( assert depth == int(depth) depth = int(depth) grid_hp = HEALPixGrid(nside0=nside0, depth=depth) - grid_r = LogarithmicGrid( + grid_r = BrokenLogGrid( min_shape=r_min_shape, r_min=r_min, + r_linthresh=r_linthresh, r_max=r_max, window_size=r_window_size, depth=depth,