diff --git a/demo/imaging_resolve_jax.py b/demo/imaging_resolve_jax.py index 40581d4ed409d30312ffd01f1b6d1a0126feb241..63154ff3070c91b1d414907d3889ee07ec8697c5 100644 --- a/demo/imaging_resolve_jax.py +++ b/demo/imaging_resolve_jax.py @@ -11,8 +11,9 @@ from matplotlib.colors import LogNorm import configparser from jax import random -response = 'ducc' -# response = "finu" +# choose between ducc0 and finufft backend +response = 'ducc0' +# response = "finufft" seed = 42 key = random.PRNGKey(seed) @@ -21,7 +22,7 @@ jax.config.update("jax_enable_x64", True) obs = rve.Observation.load("CYG-ALL-2052-2MHZ_RESOLVE_float64.npz") obs = obs.restrict_to_stokesi() -# obs = obs.average_stokesi() +obs = obs.average_stokesi() obs._weight = 0.1 * obs._weight # scale weights, as they are wrong for this specific dataset cfg = configparser.ConfigParser() cfg.read("cygnusa_2ghz.cfg") @@ -32,23 +33,16 @@ sky, additional = jrve.sky_model(cfg["sky"]) sky_sp = rve.sky_model._spatial_dom(cfg["sky"]) sky_dom = rve.default_sky_domain(sdom=sky_sp) -if response == "finu": - R_finufft = jrve.InterferometryResponseFinuFFT( - obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9 - ) - signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :]) -elif response == 'ducc': - sky_domain_dict = dict(npix_x=sky_sp.shape[0], - npix_y=sky_sp.shape[1], - pixsize_x=sky_sp.distances[0], - pixsize_y=sky_sp.distances[1], - pol_labels=['I'], - times=[0.], - freqs=[0.]) - R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9) - signal_response = lambda x: R_new(sky(x)) -else: - raise ValueError() + +sky_domain_dict = dict(npix_x=sky_sp.shape[0], + npix_y=sky_sp.shape[1], + pixsize_x=sky_sp.distances[0], + pixsize_y=sky_sp.distances[1], + pol_labels=['I'], + times=[0.], + freqs=[0.]) +R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9, backend=response) +signal_response = lambda x: R_new(sky(x)) nll = jft.Gaussian(obs.vis.val, obs.weight.val).amend(signal_response) diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py index 7912af0fb74783716a206ebf5130cc3d80fbfd39..ce2ba46e2dd79aa995eb542f9f9c8aad80c18a68 100644 --- a/resolve/re/__init__.py +++ b/resolve/re/__init__.py @@ -1,3 +1,3 @@ from .sky_model import sky_model_diffuse, sky_model_points, sky_model -from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc, InterferometryResponseOld \ No newline at end of file +from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc \ No newline at end of file diff --git a/resolve/re/response.py b/resolve/re/response.py index 43f3c22f4d8fec495d85773337aee969bdbc77b4..1c40d9cb794f8243168c6e4f083e587d44752bce 100644 --- a/resolve/re/response.py +++ b/resolve/re/response.py @@ -6,24 +6,25 @@ from functools import partial from ..util import dtype_float2complex from jax.tree_util import Partial + def get_binbounds(coordinates): if len(coordinates) == 1: - return np.array([-np.inf, np.inf]) + return np.array([-np.inf, np.inf]) c = np.array(coordinates) bounds = np.empty(self.size + 1) - bounds[1:-1] = c[:-1] + 0.5*np.diff(c) - bounds[0] = c[0] - 0.5*(c[1] - c[0]) - bounds[-1] = c[-1] + 0.5*(c[-1] - c[-2]) + bounds[1:-1] = c[:-1] + 0.5 * np.diff(c) + bounds[0] = c[0] - 0.5 * (c[1] - c[0]) + bounds[-1] = c[-1] + 0.5 * (c[-1] - c[-2]) return bounds def convert_polarization(inp, inp_pol, out_pol): - if inp_pol == ('I',): - if out_pol == ('LL', 'RR') or out_pol == ('XX', 'YY'): + if inp_pol == ("I",): + if out_pol == ("LL", "RR") or out_pol == ("XX", "YY"): new_shp = list(inp.shape) new_shp[0] = 2 return jnp.broadcast_to(inp, new_shp) - if len(out_pol) == 1 and out_pol[0] in ('I', 'RR', 'LL', 'XX', 'yy'): + if len(out_pol) == 1 and out_pol[0] in ("I", "RR", "LL", "XX", "yy"): return inp err = f"conversion of polarization {inp_pol} to {out_pol} not implemented. Please implement!" raise NotImplementedError(err) @@ -36,6 +37,7 @@ def InterferometryResponse( epsilon, nthreads=1, verbosity=0, + backend="ducc0", ): """Returns a function computing the radio interferometric response @@ -45,6 +47,8 @@ def InterferometryResponse( The observation for which the response should compute model visibilities sky_domain_dict: dict A dictionary providing information about the discretization of the sky. + do_wgridding : bool + Whether to perform wgridding. epsilon: float The numerical accuracy with which to evaluate the response. nthreads: int, optional @@ -52,39 +56,60 @@ def InterferometryResponse( verbosity: int, optional If set to 1 prints information about the setup and performance of the response. + backend : string + If `ducc0` use ducc0 wgridder. If `finufft` use finufft to compute response. """ - npix_x = sky_domain_dict['npix_x'] - npix_y = sky_domain_dict['npix_y'] - pixsize_x = sky_domain_dict['pixsize_x'] - pixsize_y = sky_domain_dict['pixsize_y'] + if do_wgridding and backend == "finufft": + raise RuntimeError("Cannot do wgridding with backend finufft.") + + npix_x = sky_domain_dict["npix_x"] + npix_y = sky_domain_dict["npix_y"] + pixsize_x = sky_domain_dict["pixsize_x"] + pixsize_y = sky_domain_dict["pixsize_y"] - n_pol = len(sky_domain_dict['pol_labels']) + n_pol = len(sky_domain_dict["pol_labels"]) # compute bins for time and freq - n_times = len(sky_domain_dict['times']) - bb_times = get_binbounds(sky_domain_dict['times']) + n_times = len(sky_domain_dict["times"]) + bb_times = get_binbounds(sky_domain_dict["times"]) - n_freqs = len(sky_domain_dict['freqs']) - bb_freqs = get_binbounds(sky_domain_dict['freqs']) + n_freqs = len(sky_domain_dict["freqs"]) + bb_freqs = get_binbounds(sky_domain_dict["freqs"]) # build responses for: time binds, freq bins sr = [] row_indices, freq_indices = [], [] for t in range(n_times): sr_tmp, t_tmp, f_tmp = [], [], [] - if tuple(bb_times[t:t+2]) == (-np.inf, np.inf): + if tuple(bb_times[t : t + 2]) == (-np.inf, np.inf): oo = observation tind = slice(None) else: - oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t+1], True) + oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t + 1], True) for f in range(n_freqs): - ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f+1], True) + ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f + 1], True) if any(np.array(ooo.vis.shape) == 0): rrr = None else: - rrr = InterferometryResponseDucc(ooo, npix_x, npix_y, pixsize_x, - pixsize_y, do_wgridding, epsilon, - nthreads, verbosity) + if backend == "ducc0": + rrr = InterferometryResponseDucc( + ooo, + npix_x, + npix_y, + pixsize_x, + pixsize_y, + do_wgridding, + epsilon, + nthreads, + verbosity, + ) + elif backend == "finufft": + rrr = InterferometryResponseFinuFFT( + ooo, pixsize_x, pixsize_y, epsilon + ) + else: + err = f"backend must be `ducc0` or `finufft` not {backend}" + raise ValueError(err) sr_tmp.append(rrr) t_tmp.append(tind) @@ -93,18 +118,18 @@ def InterferometryResponse( row_indices.append(t_tmp) freq_indices.append(f_tmp) - - target_shape = (n_pol, ) + tuple(observation.vis.shape[1:]) + target_shape = (n_pol,) + tuple(observation.vis.shape[1:]) foo = np.zeros(target_shape, np.int8) for pp in range(n_pol): for tt in range(n_times): for ff in range(n_freqs): - foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1. + foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1.0 if np.any(foo == 0): raise RuntimeError("This should not happen. Please report.") - inp_pol = tuple(sky_domain_dict['pol_labels']) + inp_pol = tuple(sky_domain_dict["pol_labels"]) out_pol = observation.vis.domain[0].labels + def apply_R(sky): res = jnp.empty(target_shape, dtype_float2complex(sky.dtype)) for pp in range(sky.shape[0]): @@ -120,33 +145,6 @@ def InterferometryResponse( return apply_R -def InterferometryResponseOld( - observation, domain, do_wgridding, epsilon, verbosity=0, nthreads=1 -): - import jax_linop - from ..response import InterferometryResponse - - R_old = InterferometryResponse( - observation, domain, do_wgridding, epsilon, verbosity, nthreads - ) - - def R(inp, out, state): - inp = ift.makeField(R_old.domain, inp) - out[()] = R_old(inp).val - - def Re_T(inp, out, state): - inp = ift.makeField(R_old.target, inp.conj()) - out[()] = R_old.adjoint(inp).val.conj() - - def R_abstract(shape, dtype, state): - return R_old.target.shape, np.dtype(np.complex128) - - def R_abstract_T(shape, dtype, state): - return R_old.domain.shape, np.dtype(np.float64) - - R_jax = jax_linop.get_linear_call(R, Re_T, R_abstract, R_abstract_T) - return lambda x: R_jax(x)[0] - def InterferometryResponseDucc( observation, @@ -195,7 +193,7 @@ def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon): def apply_finufft(inp, u, v, eps): res = vol * nufft2(inp.astype(np.complex128), u, v, eps=eps) - return jnp.expand_dims(res.reshape(-1, len(freq)), 0) + return res.reshape(-1, len(freq)) - R = partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon) + R = Partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon) return R