Skip to content
Snippets Groups Projects
Commit c659f803 authored by Jakob Roth's avatar Jakob Roth
Browse files

use JAXbind for binding ducc wgridder to JAX

parent 9f6a66e2
No related branches found
No related tags found
No related merge requests found
Pipeline #202172 passed
......@@ -38,7 +38,7 @@ Optional dependencies:
- matplotlib
- dask-ms[xarray, zarr] (for reading pfb-clean xds files)
- [jax-finufft](https://github.com/flatironinstitute/jax-finufft) (for using the finufft in jax-resolve)
- [jaxlinop](https://gitlab.mpcdf.mpg.de/mtr/jax_linop) (for using ducc gridder in jax-resolve)
- [JAXbind](https://github.com/NIFTy-PPL/JAXbind) (for using ducc gridder in jax-resolve)
## Installation
......
......@@ -11,8 +11,7 @@ from matplotlib.colors import LogNorm
import configparser
from jax import random
response = 'old'
response = 'new'
response = 'ducc'
# response = "finu"
seed = 42
......@@ -33,17 +32,12 @@ sky, additional = jrve.sky_model(cfg["sky"])
sky_sp = rve.sky_model._spatial_dom(cfg["sky"])
sky_dom = rve.default_sky_domain(sdom=sky_sp)
if response == "old":
R_rve = jrve.InterferometryResponse(
obs, sky_dom, False, 1e-9, verbosity=0, nthreads=8
)
signal_response = lambda x: R_rve(sky(x))
elif response == "finu":
if response == "finu":
R_finufft = jrve.InterferometryResponseFinuFFT(
obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9
)
signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :])
elif response == 'new':
elif response == 'ducc':
sky_domain_dict = dict(npix_x=sky_sp.shape[0],
npix_y=sky_sp.shape[1],
pixsize_x=sky_sp.distances[0],
......
......@@ -4,6 +4,7 @@ import nifty8 as ift
from functools import partial
from ..util import dtype_float2complex
from jax.tree_util import Partial
def get_binbounds(coordinates):
if len(coordinates) == 1:
......@@ -158,37 +159,22 @@ def InterferometryResponseDucc(
nthreads=1,
verbosity=0,
):
from ducc0.wgridder.experimental import dirty2vis, vis2dirty
import jax_linop
from jaxbind.contrib import jaxducc0
vol = pixsize_x * pixsize_y
nvis = observation.vis.shape[1]
_args = {
"uvw": observation.uvw,
"freq": observation.freq,
"pixsize_x": pixsize_x,
"pixsize_y": pixsize_y,
"epsilon": epsilon,
"do_wgridding": do_wgridding,
"nthreads": nthreads,
"flip_v": True,
"verbosity": verbosity,
}
def R(inp, out, state):
out[()] = dirty2vis(dirty=inp, **_args)
def Re_T(inp, out, state):
out[()] = vis2dirty(vis=inp.conj(), npix_x=npix_x, npix_y=npix_y, **_args)
def R_abstract(shape, dtype, state):
return (nvis, 1), np.dtype(np.complex128)
def R_abstract_T(shape, dtype, state):
return (npix_x, npix_y), np.dtype(np.float64)
wg = jaxducc0.get_wgridder(
pixsize_x=pixsize_x,
pixsize_y=pixsize_y,
npix_x=npix_x,
npix_y=npix_y,
epsilon=epsilon,
do_wgridding=do_wgridding,
nthreads=nthreads,
)
wgridder = Partial(wg, observation.uvw, observation.freq)
R_jax = jax_linop.get_linear_call(R, Re_T, R_abstract, R_abstract_T)
return lambda x: vol * R_jax(x)[0]
return lambda x: vol * wgridder(x)[0]
def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment