Skip to content
Snippets Groups Projects
Commit c659f803 authored by Jakob Roth's avatar Jakob Roth
Browse files

use JAXbind for binding ducc wgridder to JAX

parent 9f6a66e2
No related branches found
No related tags found
No related merge requests found
Pipeline #202172 passed
...@@ -38,7 +38,7 @@ Optional dependencies: ...@@ -38,7 +38,7 @@ Optional dependencies:
- matplotlib - matplotlib
- dask-ms[xarray, zarr] (for reading pfb-clean xds files) - dask-ms[xarray, zarr] (for reading pfb-clean xds files)
- [jax-finufft](https://github.com/flatironinstitute/jax-finufft) (for using the finufft in jax-resolve) - [jax-finufft](https://github.com/flatironinstitute/jax-finufft) (for using the finufft in jax-resolve)
- [jaxlinop](https://gitlab.mpcdf.mpg.de/mtr/jax_linop) (for using ducc gridder in jax-resolve) - [JAXbind](https://github.com/NIFTy-PPL/JAXbind) (for using ducc gridder in jax-resolve)
## Installation ## Installation
... ...
......
...@@ -11,8 +11,7 @@ from matplotlib.colors import LogNorm ...@@ -11,8 +11,7 @@ from matplotlib.colors import LogNorm
import configparser import configparser
from jax import random from jax import random
response = 'old' response = 'ducc'
response = 'new'
# response = "finu" # response = "finu"
seed = 42 seed = 42
...@@ -33,17 +32,12 @@ sky, additional = jrve.sky_model(cfg["sky"]) ...@@ -33,17 +32,12 @@ sky, additional = jrve.sky_model(cfg["sky"])
sky_sp = rve.sky_model._spatial_dom(cfg["sky"]) sky_sp = rve.sky_model._spatial_dom(cfg["sky"])
sky_dom = rve.default_sky_domain(sdom=sky_sp) sky_dom = rve.default_sky_domain(sdom=sky_sp)
if response == "old": if response == "finu":
R_rve = jrve.InterferometryResponse(
obs, sky_dom, False, 1e-9, verbosity=0, nthreads=8
)
signal_response = lambda x: R_rve(sky(x))
elif response == "finu":
R_finufft = jrve.InterferometryResponseFinuFFT( R_finufft = jrve.InterferometryResponseFinuFFT(
obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9 obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9
) )
signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :]) signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :])
elif response == 'new': elif response == 'ducc':
sky_domain_dict = dict(npix_x=sky_sp.shape[0], sky_domain_dict = dict(npix_x=sky_sp.shape[0],
npix_y=sky_sp.shape[1], npix_y=sky_sp.shape[1],
pixsize_x=sky_sp.distances[0], pixsize_x=sky_sp.distances[0],
... ...
......
...@@ -4,6 +4,7 @@ import nifty8 as ift ...@@ -4,6 +4,7 @@ import nifty8 as ift
from functools import partial from functools import partial
from ..util import dtype_float2complex from ..util import dtype_float2complex
from jax.tree_util import Partial
def get_binbounds(coordinates): def get_binbounds(coordinates):
if len(coordinates) == 1: if len(coordinates) == 1:
...@@ -158,37 +159,22 @@ def InterferometryResponseDucc( ...@@ -158,37 +159,22 @@ def InterferometryResponseDucc(
nthreads=1, nthreads=1,
verbosity=0, verbosity=0,
): ):
from ducc0.wgridder.experimental import dirty2vis, vis2dirty from jaxbind.contrib import jaxducc0
import jax_linop
vol = pixsize_x * pixsize_y vol = pixsize_x * pixsize_y
nvis = observation.vis.shape[1]
_args = {
"uvw": observation.uvw,
"freq": observation.freq,
"pixsize_x": pixsize_x,
"pixsize_y": pixsize_y,
"epsilon": epsilon,
"do_wgridding": do_wgridding,
"nthreads": nthreads,
"flip_v": True,
"verbosity": verbosity,
}
def R(inp, out, state):
out[()] = dirty2vis(dirty=inp, **_args)
def Re_T(inp, out, state):
out[()] = vis2dirty(vis=inp.conj(), npix_x=npix_x, npix_y=npix_y, **_args)
def R_abstract(shape, dtype, state): wg = jaxducc0.get_wgridder(
return (nvis, 1), np.dtype(np.complex128) pixsize_x=pixsize_x,
pixsize_y=pixsize_y,
def R_abstract_T(shape, dtype, state): npix_x=npix_x,
return (npix_x, npix_y), np.dtype(np.float64) npix_y=npix_y,
epsilon=epsilon,
do_wgridding=do_wgridding,
nthreads=nthreads,
)
wgridder = Partial(wg, observation.uvw, observation.freq)
R_jax = jax_linop.get_linear_call(R, Re_T, R_abstract, R_abstract_T) return lambda x: vol * wgridder(x)[0]
return lambda x: vol * R_jax(x)[0]
def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon): def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon):
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment