Skip to content
Snippets Groups Projects
Commit f77f52e8 authored by Jakob Roth's avatar Jakob Roth
Browse files

Merge branch 'add_polarization_radio_response' into 'master'

Polarization response

See merge request !48
parents e765832d 7c7fa2d8
Branches
No related tags found
1 merge request!48Polarization response
Pipeline #207017 passed
......@@ -35,6 +35,8 @@ build_docker_from_cache:
before_script:
- pip3 install --break-system-packages .[full]
- git clone https://github.com/NIFTy-PPL/JAXbind.git
- pip3 install --break-system-packages JAXbind/
test_resolve:
stage: testing
......
......@@ -19,12 +19,23 @@ def get_binbounds(coordinates):
def convert_polarization(inp, inp_pol, out_pol):
if inp_pol == ("I",):
if inp_pol == ("I", "Q", "U", "V"):
if out_pol == ("RR", "RL", "LR", "LL"):
mat_stokes_to_circular = jnp.array(
[[1, 0, 0, 1], [0, 1, 1, 0], [0, 1j, -1j, 0], [1, 0, 0, -1]]
)
return jnp.tensordot(mat_stokes_to_circular, inp, axes=([0], [0]))
elif out_pol == ("XX", "XY", "YX", "YY"):
mat_stokes_to_linear = jnp.array(
[[1, 1, 0, 0], [1, -1, 0, 0], [0, 0, 1, 1], [0, 0, 1j, -1j]]
)
return jnp.tensordot(mat_stokes_to_linear, inp, axes=([0], [0]))
elif inp_pol == ("I",):
if out_pol == ("LL", "RR") or out_pol == ("XX", "YY"):
new_shp = list(inp.shape)
new_shp[0] = 2
return jnp.broadcast_to(inp, new_shp)
if len(out_pol) == 1 and out_pol[0] in ("I", "RR", "LL", "XX", "yy"):
if len(out_pol) == 1 and out_pol[0] in ("I", "RR", "LL", "XX", "YY"):
return inp
err = f"conversion of polarization {inp_pol} to {out_pol} not implemented. Please implement!"
raise NotImplementedError(err)
......
......@@ -146,7 +146,7 @@ setup(
zip_safe=True,
dependency_links=[],
install_requires=["ducc0>=0.23.0", "numpy", "nifty8>=8.0"],
extras_require={"full": ("astropy", "pytest", "pytest-cov", "mpi4py", "python-casacore", "h5py", "matplotlib")},
extras_require={"full": ("astropy", "pytest", "pytest-cov", "mpi4py", "python-casacore", "h5py", "matplotlib", "jax", "jaxlib")},
ext_modules=extensions,
entry_points={"console_scripts":
[
......
......@@ -15,15 +15,12 @@
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
# Author: Philipp Arras
from os.path import join
import nifty8 as ift
import numpy as np
import pytest
import resolve as rve
from .common import setup_function, teardown_function
import resolve.re as jrve
pmp = pytest.mark.parametrize
np.seterr(all="raise")
......@@ -32,8 +29,8 @@ direc = "/data/"
OBS = []
for polmode in ["all", "stokesi", "stokesiavg"]:
oo = rve.ms2observations(
f"{direc}CYG-ALL-2052-2MHZ.ms", "DATA", True, 0, polarizations=polmode
)[0]
f"{direc}CYG-ALL-2052-2MHZ.ms", "DATA", True, 0, polarizations=polmode
)[0]
# OBS.append(oo.to_single_precision())
OBS.append(oo.to_double_precision())
npix, fov = 256, 1 * rve.DEG2RAD
......@@ -48,10 +45,18 @@ def test_single_response(obs, facets):
obs = obs.to_double_precision()
sdom = dom[-1]
mask = ift.makeField(obs.mask.domain[1:], obs.mask.val[0])
op = rve.SingleResponse(sdom, obs.uvw, obs.freq, mask=mask, facets=facets, epsilon=1e-6,
do_wgridding=False)
ift.extra.check_linear_operator(op, np.float64, np.complex128,
only_r_linear=True, rtol=1e-6, atol=1e-6)
op = rve.SingleResponse(
sdom,
obs.uvw,
obs.freq,
mask=mask,
facets=facets,
epsilon=1e-6,
do_wgridding=False,
)
ift.extra.check_linear_operator(
op, np.float64, np.complex128, only_r_linear=True, rtol=1e-6, atol=1e-6
)
def test_facet_consistency():
......@@ -66,3 +71,34 @@ def test_facet_consistency():
if res0 is None:
res0 = res
ift.extra.assert_allclose(res0, res, rtol=1e-4, atol=1e-4)
def test_jax_response_consistency():
obs = OBS[0]
obs = obs.to_double_precision()
sdom = dom[-1]
fdom = rve.IRGSpace([np.mean(OBS[0].freq)])
pdom = rve.PolarizationSpace(("I", "Q", "U", "V"))
pol_dom = rve.default_sky_domain(sdom=sdom, fdom=fdom, pdom=pdom)
radio_sky = ift.from_random(pol_dom)
radio_sky_arr = radio_sky.val
R_old = rve.InterferometryResponse(
obs, pol_dom, do_wgridding=True, epsilon=1e-9, nthreads=1
)
sky_domain_dict = dict(
npix_x=sdom.shape[0],
npix_y=sdom.shape[1],
pixsize_x=float(sdom.distances[0]),
pixsize_y=float(sdom.distances[1]),
pol_labels=["I", "Q", "U", "V"],
times=[0.0],
freqs=[0.0],
)
R_new = jrve.InterferometryResponse(obs, sky_domain_dict, True, 1e-9, nthreads=1)
vis_field_old = R_old(radio_sky).val
vis_field_new = R_new(radio_sky_arr)
np.testing.assert_allclose(vis_field_old, vis_field_new)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment