diff --git a/resolve/re/finufft_response.py b/resolve/re/finufft_response.py index 32926ad94c3a3c0a51414eb19ffc3f7c14a59d8d..a9ef168e26f458869481fb51bbdb44412adcee4c 100644 --- a/resolve/re/finufft_response.py +++ b/resolve/re/finufft_response.py @@ -1,4 +1,8 @@ import numpy as np +from jax import numpy as jnp +from jax_finufft import nufft2 +import nifty8 as ift + distances = 0.2 @@ -9,6 +13,67 @@ extent = (-halfside[0], halfside[0], -halfside[1], halfside[1]) def nufft_coords(x, y): - x = 2*np.pi * (x - extent[0] - distances[0]/2)/(extent[1] - extent[0]) - y = 2*np.pi * (y - extent[2] - distances[1]/2)/(extent[3] - extent[2]) - return np.array((x, y)) + x = x - np.min(x) + y = y - np.min(y) + max_xy = np.max([np.max(x), np.max(y)]) + x = 2*np.pi*(x / max_xy) + y = 2*np.pi*(y / max_xy) + return x, y + +def nufft_coords2(x, y, dst): + x = (2*np.pi*x*dst) % (2*np.pi) + y = (2*np.pi*y*dst) % (2*np.pi) + return x, y + + +def finufft_response(obs, distances, sky): + assert len(obs.freq) == 1 # FIXME + u, v, w = obs.uvw.T + dtype_uvw = obs.uvw.dtype + dtype_sky = sky.dtype + assert dtype_uvw == dtype_sky + sky = sky.astype(np.complex128) + u, v = nufft_coords2(u, v, distances[0]) + return nufft2(sky, u, v) + +def old_response(obs, distances, sky): + sspace = ift.RGSpace(sky.shape, distances) + sky_dom = rve.default_sky_domain(sdom=sspace) + sky_new = np.empty(sky_dom.shape) + sky_new[0,0,0,:,:] = sky + sky_new = ift.makeField(sky_dom, sky_new) + R = rve.InterferometryResponse(obs, sky_dom, False, 1e-5) + return R(sky_new), R + +########## for debug ########### + +import resolve as rve +obs = rve.Observation.load('uid___A002_Xd80784_X1ab3_only_SPT0418-47_only_spw_21.npz') +freq = obs.freq +obs = obs.restrict_by_freq(freq[0], freq[1]) +obs = obs.average_stokesi() +obs = obs.to_double_precision() +sky = np.full((100, 100), 1.0) +dist = (0.01, 0.01) +res_finufft = finufft_response(obs, dist, sky) +res_finufft = np.reshape(res_finufft, obs.vis.shape) +res_finufft = np.array(res_finufft) +res_finufft = ift.makeField(obs.vis.domain, res_finufft) + +res_ducc, R = old_response(obs, dist, sky) +# np.testing.assert_allclose(res_finufft, res_ducc.val) + + +dirty_ducc = R.adjoint(res_ducc) +dirty_finufft = R.adjoint(res_finufft) + +import matplotlib.pyplot as plt + +plt.imshow(dirty_ducc.val[0,0,0,:]) +plt.colorbar() +plt.show() + +plt.imshow(dirty_finufft.val[0,0,0,:]) +plt.colorbar() +plt.show() +