Skip to content
Snippets Groups Projects
Commit 28e62f11 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

Revert "refine healpix: Jit refinement"

This reverts commit b25d31aa.
parent 1ab4cce7
Branches
Tags
No related merge requests found
......@@ -155,101 +155,6 @@ def _refinement_matrices(
return olf, fine_kernel_sqrt
def _refine_healpix(
coarse_values,
exc,
idx_hp,
idx_r,
gc,
gf,
*,
kernel,
radial_chart=None,
precision=None
):
nside = (coarse_values.shape[0] / 12)**0.5
level = log2(nside)
if not nside.is_integer() or not level.is_integer():
raise ValueError("invalid nside of `coarse_values`")
nside, level = int(nside), int(level)
fsz_hp = 4
fsz_r = 2
csz_hp = 9
csz_r = radial_chart.coarse_size if radial_chart is not None else 3
# `idx_r` is the left-most radial pixel of the to-be-refined slice
# Extend `gc` and `gf` radially
ndim = coarse_values.ndim
if ndim == 1:
if gc.ndim != 2 or gf.ndim != 2:
raise AssertionError()
elif ndim == 2:
bc = (1, ) * (ndim - 1) + (-1, 1)
rc = radial_chart.ind2cart(
idx_r + jnp.arange(csz_r)[np.newaxis, :], level
).reshape(bc)
gc = gc[:, np.newaxis, :] * rc
gc = gc.reshape(-1, ndim + 1)
rf = radial_chart.ind2cart(
idx_r + jnp.array([0.75, 1.25])[np.newaxis, :], level
).reshape(bc)
gf = gf[:, np.newaxis, :] * rf
gf = gf.reshape(-1, ndim + 1)
else:
raise AssertionError()
olf, fks = _refinement_matrices((gc, gf), kernel=kernel)
if ndim > 1:
olf = olf.reshape(fsz_hp, fsz_r, csz_hp, csz_r)
c = coarse_values[idx_hp]
if ndim == 2:
c = dynamic_slice_in_dim(
coarse_values[idx_hp], idx_r, slice_size=csz_r, axis=1
)
refined = jnp.tensordot(olf, c, axes=ndim, precision=precision)
f_shp = (fsz_hp, ) if ndim == 1 else (fsz_hp, fsz_r)
refined += jnp.matmul(fks, exc, precision=precision).reshape(f_shp)
return refined
def _vmap_squeeze_first_2ndax(fun, *args, **kwargs):
vfun = vmap(fun, *args, **kwargs)
def vfun_apply(*x):
return vfun(jnp.squeeze(x[0], axis=1), *x[1:])
return vfun_apply
@partial(jax.jit, static_argnames=("kernel", "radial_chart", "precision"))
def _refine_healpix_1d(*args, kernel, radial_chart, precision=None, **kwargs):
refine = partial(
_refine_healpix,
kernel=kernel,
radial_chart=radial_chart,
precision=precision
)
vrefine = _vmap_squeeze_first_2ndax(
refine, in_axes=(None, 0, 0, None, 0, 0)
)
return vrefine(*args, **kwargs)
@partial(jax.jit, static_argnames=("kernel", "radial_chart", "precision"))
def _refine_healpix_2d(*args, kernel, radial_chart, precision=None, **kwargs):
refine = partial(
_refine_healpix,
kernel=kernel,
radial_chart=radial_chart,
precision=precision
)
# TODO: benchmark swapping these two
vrefine = vmap(refine, in_axes=(None, 0, None, 0, None, None))
vrefine = vmap(vrefine, in_axes=(None, 0, 0, None, 0, 0))
return vrefine(*args, **kwargs)
# %%
def matern_kernel(distance, scale, cutoff, dof):
"""Evaluates the Matern covariance kernel parametrized by its `scale`,
......@@ -280,6 +185,15 @@ def matern_kernel(distance, scale, cutoff, dof):
return jnp.where(distance < 1e-8 * cutoff, scale**2, cov)
def _vmap_squeeze_first_2ndax(fun, *args, **kwargs):
vfun = vmap(fun, *args, **kwargs)
def vfun_apply(*x):
return vfun(jnp.squeeze(x[0], axis=1), *x[1:])
return vfun_apply
def refine_slice(
radial_chart,
coarse_values,
......@@ -291,7 +205,10 @@ def refine_slice(
if ndim not in (1, 2):
raise ValueError(f"invalid dimensions {ndim!r}; expected either 0 or 1")
coarse_values = coarse_values[:, np.newaxis] if ndim == 1 else coarse_values
csz_r = radial_chart.coarse_size if radial_chart is not None else 3
fsz_hp = 4
fsz_r = 2
csz_hp = 9
csz_r = 3
nside = (coarse_values.shape[0] / 12)**0.5
level = log2(nside)
......@@ -310,18 +227,52 @@ def refine_slice(
axis=-1
)
def refine(coarse_full, exc, idx_hp, idx_r, gc, gf):
# `idx_r` is the left-most radial pixel of the to-be-refined slice
# Extend `gc` and `gf` radially
if ndim == 1:
pix_r_off = None
vrefine = _refine_healpix_1d
if gc.ndim != 2 or gf.ndim != 2:
raise AssertionError()
elif ndim == 2:
pix_r_off = jnp.arange(radial_chart.shape_at(level)[0] - csz_r + 1)
vrefine = _refine_healpix_2d
bc = (1, ) * (ndim - 1) + (-1, 1)
rc = radial_chart.ind2cart(
idx_r + jnp.arange(csz_r)[np.newaxis, :], level
).reshape(bc)
gc = gc[:, np.newaxis, :] * rc
gc = gc.reshape(-1, ndim + 1)
rf = radial_chart.ind2cart(
idx_r + jnp.array([0.75, 1.25])[np.newaxis, :], level
).reshape(bc)
gf = gf[:, np.newaxis, :] * rf
gf = gf.reshape(-1, ndim + 1)
else:
raise AssertionError()
vrefine = partial(
vrefine, kernel=kernel, radial_chart=radial_chart, precision=precision
olf, fks = _refinement_matrices((gc, gf), kernel=kernel)
if ndim > 1:
olf = olf.reshape(fsz_hp, fsz_r, csz_hp, csz_r)
c = coarse_full[idx_hp]
if ndim == 2:
c = dynamic_slice_in_dim(
coarse_full[idx_hp], idx_r, slice_size=csz_r, axis=1
)
refined = jnp.tensordot(olf, c, axes=ndim, precision=precision)
f_shp = (fsz_hp, ) if ndim == 1 else (fsz_hp, fsz_r)
refined += jnp.matmul(fks, exc, precision=precision).reshape(f_shp)
return refined
# TODO: benchmark swapping these two
if ndim == 1:
pix_r_off = None
vrefine = _vmap_squeeze_first_2ndax(
refine, in_axes=(None, 0, 0, None, 0, 0)
)
elif ndim == 2:
pix_r_off = jnp.arange(radial_chart.shape_at(level)[0] - csz_r + 1)
vrefine = vmap(refine, in_axes=(None, 0, None, 0, None, None))
vrefine = vmap(vrefine, in_axes=(None, 0, 0, None, 0, 0))
else:
raise AssertionError()
refined = vrefine(
coarse_values, excitations, pix_nbr_idx, pix_r_off, gc, gf
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment