Skip to content
Snippets Groups Projects
Commit aaa56ce4 authored by Jakob Roth's avatar Jakob Roth
Browse files

resolve.re: unify ducc and finufft response

parent 8c8811c7
No related branches found
No related tags found
No related merge requests found
Pipeline #202537 passed
......@@ -11,8 +11,9 @@ from matplotlib.colors import LogNorm
import configparser
from jax import random
response = 'ducc'
# response = "finu"
# choose between ducc0 and finufft backend
response = 'ducc0'
# response = "finufft"
seed = 42
key = random.PRNGKey(seed)
......@@ -21,7 +22,7 @@ jax.config.update("jax_enable_x64", True)
obs = rve.Observation.load("CYG-ALL-2052-2MHZ_RESOLVE_float64.npz")
obs = obs.restrict_to_stokesi()
# obs = obs.average_stokesi()
obs = obs.average_stokesi()
obs._weight = 0.1 * obs._weight # scale weights, as they are wrong for this specific dataset
cfg = configparser.ConfigParser()
cfg.read("cygnusa_2ghz.cfg")
......@@ -32,23 +33,16 @@ sky, additional = jrve.sky_model(cfg["sky"])
sky_sp = rve.sky_model._spatial_dom(cfg["sky"])
sky_dom = rve.default_sky_domain(sdom=sky_sp)
if response == "finu":
R_finufft = jrve.InterferometryResponseFinuFFT(
obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9
)
signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :])
elif response == 'ducc':
sky_domain_dict = dict(npix_x=sky_sp.shape[0],
npix_y=sky_sp.shape[1],
pixsize_x=sky_sp.distances[0],
pixsize_y=sky_sp.distances[1],
pol_labels=['I'],
times=[0.],
freqs=[0.])
R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9)
signal_response = lambda x: R_new(sky(x))
else:
raise ValueError()
sky_domain_dict = dict(npix_x=sky_sp.shape[0],
npix_y=sky_sp.shape[1],
pixsize_x=sky_sp.distances[0],
pixsize_y=sky_sp.distances[1],
pol_labels=['I'],
times=[0.],
freqs=[0.])
R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9, backend=response)
signal_response = lambda x: R_new(sky(x))
nll = jft.Gaussian(obs.vis.val, obs.weight.val).amend(signal_response)
......
from .sky_model import sky_model_diffuse, sky_model_points, sky_model
from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc, InterferometryResponseOld
\ No newline at end of file
from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
\ No newline at end of file
......@@ -6,24 +6,25 @@ from functools import partial
from ..util import dtype_float2complex
from jax.tree_util import Partial
def get_binbounds(coordinates):
if len(coordinates) == 1:
return np.array([-np.inf, np.inf])
return np.array([-np.inf, np.inf])
c = np.array(coordinates)
bounds = np.empty(self.size + 1)
bounds[1:-1] = c[:-1] + 0.5*np.diff(c)
bounds[0] = c[0] - 0.5*(c[1] - c[0])
bounds[-1] = c[-1] + 0.5*(c[-1] - c[-2])
bounds[1:-1] = c[:-1] + 0.5 * np.diff(c)
bounds[0] = c[0] - 0.5 * (c[1] - c[0])
bounds[-1] = c[-1] + 0.5 * (c[-1] - c[-2])
return bounds
def convert_polarization(inp, inp_pol, out_pol):
if inp_pol == ('I',):
if out_pol == ('LL', 'RR') or out_pol == ('XX', 'YY'):
if inp_pol == ("I",):
if out_pol == ("LL", "RR") or out_pol == ("XX", "YY"):
new_shp = list(inp.shape)
new_shp[0] = 2
return jnp.broadcast_to(inp, new_shp)
if len(out_pol) == 1 and out_pol[0] in ('I', 'RR', 'LL', 'XX', 'yy'):
if len(out_pol) == 1 and out_pol[0] in ("I", "RR", "LL", "XX", "yy"):
return inp
err = f"conversion of polarization {inp_pol} to {out_pol} not implemented. Please implement!"
raise NotImplementedError(err)
......@@ -36,6 +37,7 @@ def InterferometryResponse(
epsilon,
nthreads=1,
verbosity=0,
backend="ducc0",
):
"""Returns a function computing the radio interferometric response
......@@ -45,6 +47,8 @@ def InterferometryResponse(
The observation for which the response should compute model visibilities
sky_domain_dict: dict
A dictionary providing information about the discretization of the sky.
do_wgridding : bool
Whether to perform wgridding.
epsilon: float
The numerical accuracy with which to evaluate the response.
nthreads: int, optional
......@@ -52,39 +56,60 @@ def InterferometryResponse(
verbosity: int, optional
If set to 1 prints information about the setup and performance of the
response.
backend : string
If `ducc0` use ducc0 wgridder. If `finufft` use finufft to compute response.
"""
npix_x = sky_domain_dict['npix_x']
npix_y = sky_domain_dict['npix_y']
pixsize_x = sky_domain_dict['pixsize_x']
pixsize_y = sky_domain_dict['pixsize_y']
if do_wgridding and backend == "finufft":
raise RuntimeError("Cannot do wgridding with backend finufft.")
npix_x = sky_domain_dict["npix_x"]
npix_y = sky_domain_dict["npix_y"]
pixsize_x = sky_domain_dict["pixsize_x"]
pixsize_y = sky_domain_dict["pixsize_y"]
n_pol = len(sky_domain_dict['pol_labels'])
n_pol = len(sky_domain_dict["pol_labels"])
# compute bins for time and freq
n_times = len(sky_domain_dict['times'])
bb_times = get_binbounds(sky_domain_dict['times'])
n_times = len(sky_domain_dict["times"])
bb_times = get_binbounds(sky_domain_dict["times"])
n_freqs = len(sky_domain_dict['freqs'])
bb_freqs = get_binbounds(sky_domain_dict['freqs'])
n_freqs = len(sky_domain_dict["freqs"])
bb_freqs = get_binbounds(sky_domain_dict["freqs"])
# build responses for: time binds, freq bins
sr = []
row_indices, freq_indices = [], []
for t in range(n_times):
sr_tmp, t_tmp, f_tmp = [], [], []
if tuple(bb_times[t:t+2]) == (-np.inf, np.inf):
if tuple(bb_times[t : t + 2]) == (-np.inf, np.inf):
oo = observation
tind = slice(None)
else:
oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t+1], True)
oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t + 1], True)
for f in range(n_freqs):
ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f+1], True)
ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f + 1], True)
if any(np.array(ooo.vis.shape) == 0):
rrr = None
else:
rrr = InterferometryResponseDucc(ooo, npix_x, npix_y, pixsize_x,
pixsize_y, do_wgridding, epsilon,
nthreads, verbosity)
if backend == "ducc0":
rrr = InterferometryResponseDucc(
ooo,
npix_x,
npix_y,
pixsize_x,
pixsize_y,
do_wgridding,
epsilon,
nthreads,
verbosity,
)
elif backend == "finufft":
rrr = InterferometryResponseFinuFFT(
ooo, pixsize_x, pixsize_y, epsilon
)
else:
err = f"backend must be `ducc0` or `finufft` not {backend}"
raise ValueError(err)
sr_tmp.append(rrr)
t_tmp.append(tind)
......@@ -93,18 +118,18 @@ def InterferometryResponse(
row_indices.append(t_tmp)
freq_indices.append(f_tmp)
target_shape = (n_pol, ) + tuple(observation.vis.shape[1:])
target_shape = (n_pol,) + tuple(observation.vis.shape[1:])
foo = np.zeros(target_shape, np.int8)
for pp in range(n_pol):
for tt in range(n_times):
for ff in range(n_freqs):
foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1.
foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1.0
if np.any(foo == 0):
raise RuntimeError("This should not happen. Please report.")
inp_pol = tuple(sky_domain_dict['pol_labels'])
inp_pol = tuple(sky_domain_dict["pol_labels"])
out_pol = observation.vis.domain[0].labels
def apply_R(sky):
res = jnp.empty(target_shape, dtype_float2complex(sky.dtype))
for pp in range(sky.shape[0]):
......@@ -120,33 +145,6 @@ def InterferometryResponse(
return apply_R
def InterferometryResponseOld(
observation, domain, do_wgridding, epsilon, verbosity=0, nthreads=1
):
import jax_linop
from ..response import InterferometryResponse
R_old = InterferometryResponse(
observation, domain, do_wgridding, epsilon, verbosity, nthreads
)
def R(inp, out, state):
inp = ift.makeField(R_old.domain, inp)
out[()] = R_old(inp).val
def Re_T(inp, out, state):
inp = ift.makeField(R_old.target, inp.conj())
out[()] = R_old.adjoint(inp).val.conj()
def R_abstract(shape, dtype, state):
return R_old.target.shape, np.dtype(np.complex128)
def R_abstract_T(shape, dtype, state):
return R_old.domain.shape, np.dtype(np.float64)
R_jax = jax_linop.get_linear_call(R, Re_T, R_abstract, R_abstract_T)
return lambda x: R_jax(x)[0]
def InterferometryResponseDucc(
observation,
......@@ -195,7 +193,7 @@ def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon):
def apply_finufft(inp, u, v, eps):
res = vol * nufft2(inp.astype(np.complex128), u, v, eps=eps)
return jnp.expand_dims(res.reshape(-1, len(freq)), 0)
return res.reshape(-1, len(freq))
R = partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon)
R = Partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon)
return R
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment