diff --git a/src/re/__init__.py b/src/re/__init__.py
index c6e4495617f6f4c003fa533577b6c1891d479414..1c1b9b88721c737d667a0f20849d34597aa03692 100644
--- a/src/re/__init__.py
+++ b/src/re/__init__.py
@@ -39,8 +39,10 @@ from .multi_grid import (
     Grid,
     HEALPixGrid,
     HPLogRGrid,
+    HPBrokenLogRGrid,
     ICRField,
-    LogarithmicGrid,
+    LogGrid,
+    BrokenLogGrid,
     MGrid,
     SimpleOpenGrid,
 )
diff --git a/src/re/multi_grid/__init__.py b/src/re/multi_grid/__init__.py
index ecbcc7ff3064281b118de16b41767b9a9ef8d787..07d395a5cfa47d5fc20426b01bd15f2a64f5161e 100644
--- a/src/re/multi_grid/__init__.py
+++ b/src/re/multi_grid/__init__.py
@@ -2,4 +2,4 @@
 
 from .correlated_field import ICRField
 from .grid import Grid, MGrid
-from .grid_impl import HEALPixGrid, HPLogRGrid, LogarithmicGrid, SimpleOpenGrid
+from .grid_impl import HEALPixGrid, HPLogRGrid, LogGrid, SimpleOpenGrid, BrokenLogGrid, HPBrokenLogRGrid
diff --git a/src/re/multi_grid/grid.py b/src/re/multi_grid/grid.py
index 0f559c05fb8b134c8d73d0d8f683e5ec959a61c7..9f1a20ae2c4ec90790b83c3f9415a985a550a1a9 100644
--- a/src/re/multi_grid/grid.py
+++ b/src/re/multi_grid/grid.py
@@ -112,8 +112,11 @@ class GridAtLevel:
 
     def coord2index(self, coord, dtype=np.uint64):
         slc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1)
-        # TODO type casting
-        return (coord * self.shape[slc] - 0.5).astype(dtype)
+        index = (coord * self.shape[slc] - 0.5)
+        if np.issubdtype(dtype, np.integer):
+            return np.rint(index).astype(dtype)
+        else:
+            raise ValueError(f"non-integer index dtype: {dtype}")
 
     def index2volume(self, index):
         return np.array(1.0 / self.size)[(np.newaxis,) * index.ndim]
@@ -234,10 +237,11 @@ class OpenGridAtLevel(GridAtLevel):
     def coord2index(self, coord, dtype=np.uint64):
         slc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1)
         shp = self.shape + 2 * self.shifts
-        # TODO type
         index = coord * shp[slc] - self.shifts[slc] - 0.5
-        # assert jnp.all(index >= 0)
-        return index.astype(dtype)
+        if np.issubdtype(dtype, np.integer):
+            return np.rint(index).astype(dtype)
+        else:
+            raise ValueError(f"non-integer index dtype: {dtype}")
 
     def index2volume(self, index):
         sz = np.prod(self.shape + 2 * self.shifts)
diff --git a/src/re/multi_grid/grid_impl.py b/src/re/multi_grid/grid_impl.py
index 9931bece9643ba8c0db86e631702c4629b3f536f..f307bca24191759035790b7eb50b7297ae1a1e5e 100644
--- a/src/re/multi_grid/grid_impl.py
+++ b/src/re/multi_grid/grid_impl.py
@@ -21,7 +21,6 @@ from .grid import Grid, GridAtLevel, MGrid, MGridAtLevel, OpenGrid, OpenGridAtLe
 class HEALPixGridAtLevel(GridAtLevel):
     nside: int
     nest: bool
-    fill_strategy: str
 
     def __init__(
         self,
@@ -113,24 +112,56 @@ class HEALPixGrid(Grid):
     ):
         """HEALPix pixelization grid.
 
-        The grid can be defined either via the final nside, the initial nside
-        (nside0), or the initial shape (shape0).
+        The grid can be defined either via depth and exactly one of (nside0, nside, shape0)
+        or via nside and one of (nside0, shape0).
         """
         self.nest = nest
+
+        assert shape0 is None or isinstance(shape0, (int, tuple, np.ndarray))
         if shape0 is not None:
             assert nside0 is None
-            assert isinstance(shape0, int) or np.ndim(shape0) == 0
-            shape0 = shape0[0] if np.ndim(shape0) > 0 else shape0
+            shape0 = np.asarray(shape0).ravel()
+            assert shape0.size == 1, (
+                "shape0 must be a scalar or a single-element array/tuple"
+            )
+            (shape0,) = shape0
+            assert isinstance(shape0, int)
+            # Check whether the shape is a valid HEALPix shape
+            assert shape0 > 0 and shape0 % 12 == 0
             nside0 = (shape0 / 12) ** 0.5
-            assert int(nside0) == nside0
-            nside0 = int(nside0)
-        if nside is not None:
-            assert nside0 is None
-            assert depth is not None
-            nside0 = nside / 2**depth
-            assert int(nside0) == nside0
+            assert np.isclose(nside0, round(nside0), atol=1.0e-10)
+            nside0 = round(nside0)
+
+        assert nside is None or (
+            isinstance(nside, int)
+            and nside > 0
+            and (nside & (nside - 1)) == 0  # power of 2
+        )
+        assert nside0 is None or (
+            isinstance(nside0, int)
+            and nside0 > 0
+            and (nside0 & (nside0 - 1)) == 0  # power of 2
+        )
+        assert depth is None or (isinstance(depth, int) and depth >= 0)
+
+        if depth is not None:
+            if (nside0 is None) == (nside is None):
+                raise ValueError(
+                    "Ambiguous initialisation of HEALPixGrid. If depth is given, please supply exactly one of (nside0, nside, shape0)"
+                )
+            if nside is not None:
+                nside0 = nside // 2**depth
+        else:
+            if (nside is None) or (nside0 is None):
+                raise ValueError(
+                    "Ambiguous initialisation of HEALPixGrid. If depth is not given, please supply nside and exactly one of (nside0, shape0)"
+                )
+            assert nside0 <= nside
+            depth = np.log2(nside / nside0)
+            assert np.isclose(depth, round(depth), atol=1.0e-10)
+            depth = round(depth)
+
         self.nside0 = nside0
-        assert self.nside0 > 0
         if splits is None:
             splits = (4,) * depth
         super().__init__(
@@ -181,7 +212,7 @@ class SimpleOpenGridAtLevel(OpenGridAtLevel):
     def coord2index(self, coord, dtype=np.uint64):
         bc = (slice(None),) + (np.newaxis,) * (coord.ndim - 1)
         coord = coord / ((self.shape + 2 * self.shifts) * self.distances)[bc]
-        return super().coord2index(self, coord, dtype=dtype)
+        return super().coord2index(coord, dtype=dtype)
 
     def index2volume(self, index):
         vol = super().index2volume(index)
@@ -267,7 +298,7 @@ def SimpleOpenGrid(
     )
 
 
-class LogarithmicGridAtLevel(SimpleOpenGridAtLevel):
+class LogGridAtLevel(SimpleOpenGridAtLevel):
     def __init__(self, *args, coord_offset, coord_scale, **kwargs):
         # NOTE, technically `coord_offset` and `coord_scale` are redundant with
         # `shifts` and `distances`, however, for ease of use, we first let them
@@ -290,15 +321,15 @@ class LogarithmicGridAtLevel(SimpleOpenGridAtLevel):
 
     def coord2index(self, coord, dtype=np.uint64):
         coord = (jnp.log(coord) - self.coord_offset) / self.coord_scale
-        return super().coord2index(self, coord, dtype=dtype)
+        return super().coord2index(coord, dtype=dtype)
 
     def index2volume(self, index):
         a = (slice(None),) + (np.newaxis,) * index.ndim
-        coords = super().index2coord(index + jnp.array([-0.5, 0.5])[a])
-        return jnp.prod(coords[1] - coords[0], axis=0)
+        coords = self.index2coord(index + jnp.array([-0.5, 0.5])[a])
+        return jnp.prod(coords[1] - coords[0], axis=0, keepdims=True)
 
 
-def LogarithmicGrid(
+def LogGrid(
     *,
     r_min: float,
     r_max: float,
@@ -310,19 +341,19 @@ def LogarithmicGrid(
     """
     if distances is not None:
         raise ValueError("`distances` are incompatible with a logarithmic grid")
-    if r_min < 0.0 or r_max < r_min:
+    if r_min <= 0.0 or r_max <= r_min:
         raise ValueError(f"invalid r_min {r_min!r} or r_max {r_max!r}")
     coord_offset = np.log(r_min)
     coord_scale = np.log(r_max) - coord_offset
     return SimpleOpenGrid(
         **kwargs,
         atLevel=partial(
-            LogarithmicGridAtLevel, coord_offset=coord_offset, coord_scale=coord_scale
+            LogGridAtLevel, coord_offset=coord_offset, coord_scale=coord_scale
         ),
     )
 
 
-class HPLogRGridAtLevel(MGridAtLevel):
+class HPRadialGridAtLevel(MGridAtLevel):
     def index2coord(self, index, **kwargs):
         coords = super().index2coord(index, **kwargs)
         return coords[:3] * coords[3]
@@ -333,6 +364,13 @@ class HPLogRGridAtLevel(MGridAtLevel):
         coord = jnp.concatenate((coord / r, r), axis=0)
         return super().coord2index(coord, **kwargs)
 
+    def index2volume(self, index):
+        grid_hp, grid_r = self.grids
+        r_upper = grid_r.index2coord(index[1:2] + 0.5)
+        r_lower = grid_r.index2coord(index[1:2] - 0.5)
+        A_unity = grid_hp.index2volume(index[0:1])
+        return A_unity * (r_upper**3 - r_lower**3) / 3
+
 
 def HPLogRGrid(
     min_shape: Optional[Tuple[int, int]] = None,
@@ -343,11 +381,193 @@ def HPLogRGrid(
     r_max,
     r_window_size=3,
     nside0=16,
-    atLevel=HPLogRGridAtLevel,
+    atLevel=HPRadialGridAtLevel,
 ) -> MGrid:
     """Meshgrid of a HEALPix grid and a logarithmic grid.
 
-    See `HEALPixGrid` and `LogarithmicGrid`."""
+    See `HEALPixGrid` and `LogGrid`."""
+    if r_min_shape is None and nside is None:
+        hp_size, r_min_shape = min_shape
+        nside = (hp_size / 12) ** 0.5
+    depth = np.log2(nside / nside0)
+    assert depth == int(depth)
+    depth = int(depth)
+    grid_hp = HEALPixGrid(nside0=nside0, depth=depth)
+    grid_r = LogGrid(
+        min_shape=r_min_shape,
+        r_min=r_min,
+        r_max=r_max,
+        window_size=r_window_size,
+        depth=depth,
+    )
+    return MGrid(grid_hp, grid_r, atLevel=atLevel)
+
+
+class BrokenLogGridAtLevel(SimpleOpenGridAtLevel):
+    def __init__(
+        self,
+        *args,
+        alpha,
+        beta,
+        gamma,
+        delta,
+        epsilon,
+        r_min,
+        r_linthresh,
+        r_max,
+        rg_min,
+        rg_linthresh,
+        rg_max,
+        **kwargs,
+    ):
+        self._alpha = alpha
+        self._beta = beta
+        self._gamma = gamma
+        self._delta = delta
+        self._epsilon = epsilon
+        self._r_min = r_min
+        self._r_linthresh = r_linthresh
+        self._r_max = r_max
+        self._rg_min = rg_min
+        self._rg_linthresh = rg_linthresh
+        self._rg_max = rg_max
+        super().__init__(*args, **kwargs)
+
+    @property
+    def r_min(self):
+        return self.index2coord(np.array([-0.5]))
+
+    @property
+    def r_max(self):
+        return self.index2coord(np.array([self.shape[0] - 0.5]))
+
+    def index2coord(self, index):
+        # map to in-between 0 and 1
+        coord = super().index2coord(index)
+        # map to in-between r_min and r_max
+        condlist = [
+            coord < self._rg_min,
+            (self._rg_min <= coord) & (coord < self._rg_linthresh),
+            (self._rg_linthresh <= coord) & (coord < self._rg_max),
+            self._rg_max <= coord,
+        ]
+        funclist = [
+            lambda rg: self._gamma / (rg - self._delta),
+            lambda rg: self._r_min + self._alpha * (rg - self._rg_min),
+            lambda rg: self._r_linthresh
+            * jnp.exp(self._beta * (rg - self._rg_linthresh)),
+            lambda rg: self._r_max + self._epsilon * (rg - self._rg_max),
+        ]
+        return jnp.piecewise(coord, condlist, funclist)
+
+    def coord2index(self, coord, dtype=np.uint64):
+        # map to in-between 0 and 1
+        condlist = [
+            coord < self._r_min,
+            (self._r_min <= coord) & (coord < self._r_linthresh),
+            (self._r_linthresh <= coord) & (coord < self._r_max),
+            self._r_max <= coord,
+        ]
+        funclist = [
+            lambda r: self._delta + self._gamma / r,
+            lambda r: self._rg_min + (r - self._r_min) / self._alpha,
+            lambda r: self._rg_linthresh + jnp.log(r / self._r_linthresh) / self._beta,
+            lambda r: self._rg_max + (r - self._r_max) / self._epsilon,
+        ]
+        coord = jnp.piecewise(coord, condlist, funclist)
+        # transform to index
+        return super().coord2index(coord, dtype=dtype)
+
+    def index2volume(self, index):
+        a = (slice(None),) + (np.newaxis,) * index.ndim
+        coords = self.index2coord(index + jnp.array([-0.5, 0.5])[a])
+        return jnp.prod(coords[1] - coords[0], axis=0, keepdims=True)
+
+
+def BrokenLogGrid(
+    *,
+    r_min: float,
+    r_linthresh: float,
+    r_max: float,
+    distances=None,
+    **kwargs,
+) -> OpenGrid:
+    """Create a broken logarithmic grid on top of `SimpleOpenGrid` spanning from
+    `r_min` to `r_max` at the final depth.
+    The grid is parametrised by three radii: r_min, r_linthresh, and r_max.
+    Between r_min and r_linthresh pixels are spaced linearly (r).
+    Between r_linthresh and r_max pixels are spaced logarithmically (exp(r)).
+
+    For available parameters see the `SimpleOpenGrid` docstring in addition the ones below.
+
+    Parameters
+    ----------
+    r_min:
+        Minimum coordinate value.
+    r_linthresh:
+        Coordinate value at which the grid switches from linear to logarithmic spacing.
+    r_max:
+        Maximum coordinate value.
+
+    Notes
+    -----
+    For values below rmin, the (padded) pixels are spaced antilinearly (1/r).
+    Above rmax they are spaced linearly (r).
+    """
+    if distances is not None:
+        raise ValueError("`distances` are incompatible with a logarithmic grid")
+    if r_min <= 0.0 or r_max <= r_min:
+        raise ValueError(f"invalid r_min {r_min!r} or r_max {r_max!r}")
+    if r_linthresh < r_min or r_max <= r_linthresh:
+        raise ValueError(f"invalid r_0 {r_linthresh!r}")
+
+    # This parametrisation is technically capable of handling a transformation
+    # from arbitrary rg_min and rg_max, but in accordance to the LogGrid,
+    # we can fix them to 0 and 1 and use the parent class for mapping them there.
+    rg_min = 0.0
+    rg_max = 1.0
+    m = (1.0 - r_min / r_linthresh) / (jnp.log(r_max / r_linthresh))
+    rg_linthresh = rg_min / (1 + m) + rg_max * m / (1 + m)
+    alpha = r_linthresh / (rg_max - rg_linthresh) * jnp.log(r_max / r_linthresh)
+    beta = alpha / r_linthresh
+    gamma = -(r_min**2) / alpha
+    delta = rg_min + r_min / alpha
+    epsilon = r_linthresh * beta * jnp.exp(beta * (rg_max - rg_linthresh))
+
+    return SimpleOpenGrid(
+        **kwargs,
+        atLevel=partial(
+            BrokenLogGridAtLevel,
+            alpha=alpha,
+            beta=beta,
+            gamma=gamma,
+            delta=delta,
+            epsilon=epsilon,
+            r_min=r_min,
+            r_linthresh=r_linthresh,
+            r_max=r_max,
+            rg_min=rg_min,
+            rg_linthresh=rg_linthresh,
+            rg_max=rg_max,
+        ),
+    )
+
+
+def HPBrokenLogRGrid(
+    min_shape: Optional[Tuple[int, int]] = None,
+    *,
+    nside: Optional[int] = None,
+    r_min_shape: Optional[int] = None,
+    r_min,
+    r_linthresh,
+    r_max,
+    r_window_size=3,
+    nside0=16,
+    atLevel=HPRadialGridAtLevel,
+) -> MGrid:
+    """Meshgrid of a HEALPix grid and a broken logarithmic grid.
+
+    See `HEALPixGrid` and `BrokenLogGrid`."""
     if r_min_shape is None and nside is None:
         hp_size, r_min_shape = min_shape
         nside = (hp_size / 12) ** 0.5
@@ -355,9 +575,10 @@ def HPLogRGrid(
     assert depth == int(depth)
     depth = int(depth)
     grid_hp = HEALPixGrid(nside0=nside0, depth=depth)
-    grid_r = LogarithmicGrid(
+    grid_r = BrokenLogGrid(
         min_shape=r_min_shape,
         r_min=r_min,
+        r_linthresh=r_linthresh,
         r_max=r_max,
         window_size=r_window_size,
         depth=depth,