Skip to content
Snippets Groups Projects
Commit 2d9445b0 authored by Jakob Roth's avatar Jakob Roth
Browse files

test_sky_models: skip polarisation consistency check

parent 3c0cb70a
No related branches found
No related tags found
No related merge requests found
Pipeline #211205 passed
...@@ -37,30 +37,48 @@ from .common import setup_function, teardown_function ...@@ -37,30 +37,48 @@ from .common import setup_function, teardown_function
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
np.seterr(all="raise") np.seterr(all="raise")
obs = rve.ms2observations("/data/CYG-D-6680-64CH-10S.ms", "DATA", True, 0, polarizations="all") obs = rve.ms2observations(
"/data/CYG-D-6680-64CH-10S.ms", "DATA", True, 0, polarizations="all"
)
class StokesAdder(jft.Model): class StokesAdder(jft.Model):
def __init__(self, correlated_field_dict): def __init__(self, correlated_field_dict):
self.cfs = correlated_field_dict self.cfs = correlated_field_dict
super().__init__(init=reduce(operator.or_, super().__init__(
[value.init for value in self.cfs.values()])) init=reduce(operator.or_, [value.init for value in self.cfs.values()])
)
def __call__(self, x): def __call__(self, x):
def get_stokes(pre_stokes): def get_stokes(pre_stokes):
pol_int = jnp.sqrt(sum(pre_stokes[i]**2 for i in range(1,4))) pol_int = jnp.sqrt(sum(pre_stokes[i] ** 2 for i in range(1, 4)))
return jnp.concatenate([jnp.exp(pre_stokes[:1])*jnp.cosh(pol_int), return jnp.concatenate(
(jnp.exp(pre_stokes[:1])*jnp.sinh(pol_int)/pol_int)*pre_stokes[1:]]) [
jnp.exp(pre_stokes[:1]) * jnp.cosh(pol_int),
(jnp.exp(pre_stokes[:1]) * jnp.sinh(pol_int) / pol_int)
* pre_stokes[1:],
]
)
pre_stokes = jnp.stack([cf(x) for cf in self.cfs.values()]) pre_stokes = jnp.stack([cf(x) for cf in self.cfs.values()])
dims_remaining = pre_stokes.shape[1:] dims_remaining = pre_stokes.shape[1:]
pre_stokes = pre_stokes.reshape((4,-1)) pre_stokes = pre_stokes.reshape((4, -1))
stokes = jax.vmap(get_stokes, in_axes=1, out_axes=-1)(pre_stokes) stokes = jax.vmap(get_stokes, in_axes=1, out_axes=-1)(pre_stokes)
return stokes.reshape((4,) + dims_remaining) return stokes.reshape((4,) + dims_remaining)
@pmp("fname", ["cfg/cygnusa.cfg", "cfg/cygnusa_polarization.cfg", "cfg/mf.cfg",
"cfg/cygnusa_mf.cfg", "cfg/cygnusa_mf_cfm.cfg"]) @pmp(
"fname",
[
"cfg/cygnusa.cfg",
"cfg/cygnusa_polarization.cfg",
"cfg/mf.cfg",
"cfg/cygnusa_mf.cfg",
"cfg/cygnusa_mf_cfm.cfg",
],
)
def test_build_multi_frequency_skymodel(fname): def test_build_multi_frequency_skymodel(fname):
tmp = TemporaryDirectory() tmp = TemporaryDirectory()
direc = tmp.name direc = tmp.name
...@@ -69,7 +87,7 @@ def test_build_multi_frequency_skymodel(fname): ...@@ -69,7 +87,7 @@ def test_build_multi_frequency_skymodel(fname):
op, _ = rve.sky_model_diffuse(cfg["sky"], obs) op, _ = rve.sky_model_diffuse(cfg["sky"], obs)
out = op(ift.from_random(op.domain)) out = op(ift.from_random(op.domain))
if not fname == "cfg/cygnusa_mf_cfm.cfg": # FIXME: overflow in float32 conversion if not fname == "cfg/cygnusa_mf_cfm.cfg": # FIXME: overflow in float32 conversion
rve.ubik_tools.field2fits(out, join(direc, "tmp.fits")) rve.ubik_tools.field2fits(out, join(direc, "tmp.fits"))
key1 = op.domain.keys() key1 = op.domain.keys()
...@@ -82,35 +100,56 @@ def test_build_multi_frequency_skymodel(fname): ...@@ -82,35 +100,56 @@ def test_build_multi_frequency_skymodel(fname):
key2 = op.domain.keys() key2 = op.domain.keys()
assert len(set(key1) & set(key2)) == 0 assert len(set(key1) & set(key2)) == 0
@pytest.mark.skip(reason="pol_test.cfg not on funk")
def test_jax_skymodel_consistency(): def test_jax_skymodel_consistency():
cfg = configparser.ConfigParser() cfg = configparser.ConfigParser()
cfg.read('pol_test.cfg') cfg.read("pol_test.cfg")
diffuse, additional = rve.sky_model_diffuse(cfg['sky']) diffuse, additional = rve.sky_model_diffuse(cfg["sky"])
pols_dict = dict() pols_dict = dict()
for pol_lbl in ('i', 'q', 'u', 'v'): for pol_lbl in ("i", "q", "u", "v"):
pols_dict |= {pol_lbl: dict()} pols_dict |= {pol_lbl: dict()}
for key in cfg['sky'].keys(): for key in cfg["sky"].keys():
if 'stokes'+pol_lbl in key: if "stokes" + pol_lbl in key:
pols_dict[pol_lbl] |= {key.removeprefix('stokes'+pol_lbl+' diffuse space i0 '): cfg['sky'][key]} pols_dict[pol_lbl] |= {
key.removeprefix("stokes" + pol_lbl + " diffuse space i0 "): cfg[
"sky"
][key]
}
sky_shape = (int(cfg['sky']['space npix x']),int(cfg['sky']['space npix y'])) sky_shape = (int(cfg["sky"]["space npix x"]), int(cfg["sky"]["space npix y"]))
distances = tuple(1/npix for npix in sky_shape) distances = tuple(1 / npix for npix in sky_shape)
correlated_field_dict = dict() correlated_field_dict = dict()
for key, val in pols_dict.items(): for key, val in pols_dict.items():
cfm = jft.CorrelatedFieldMaker('stokes'+key.upper()+' diffuse') cfm = jft.CorrelatedFieldMaker("stokes" + key.upper() + " diffuse")
cfm.set_amplitude_total_offset(offset_mean=float(val['zero mode offset']), offset_std=(float(val['zero mode mean']), float(val['zero mode stddev']))) cfm.set_amplitude_total_offset(
cfm.add_fluctuations(sky_shape, offset_mean=float(val["zero mode offset"]),
distances=distances, offset_std=(float(val["zero mode mean"]), float(val["zero mode stddev"])),
asperity=(float(val['asperity mean']), float(val['asperity stddev'])), )
loglogavgslope=(float(val['loglogavgslope mean']), float(val['loglogavgslope stddev'])), cfm.add_fluctuations(
flexibility=(float(val['flexibility mean']), float(val['flexibility stddev'])), sky_shape,
fluctuations=(float(val['fluctuations mean']), float(val['fluctuations stddev'])), distances=distances,
prefix='space i0',non_parametric_kind="power") asperity=(float(val["asperity mean"]), float(val["asperity stddev"])),
correlated_field_dict |= {'prestokes '+key: cfm.finalize()} loglogavgslope=(
float(val["loglogavgslope mean"]),
float(val["loglogavgslope stddev"]),
),
flexibility=(
float(val["flexibility mean"]),
float(val["flexibility stddev"]),
),
fluctuations=(
float(val["fluctuations mean"]),
float(val["fluctuations stddev"]),
),
prefix="space i0",
non_parametric_kind="power",
)
correlated_field_dict |= {"prestokes " + key: cfm.finalize()}
stokes_sky = StokesAdder(correlated_field_dict) stokes_sky = StokesAdder(correlated_field_dict)
...@@ -118,10 +157,10 @@ def test_jax_skymodel_consistency(): ...@@ -118,10 +157,10 @@ def test_jax_skymodel_consistency():
re_pos_init = pos_init.val re_pos_init = pos_init.val
for key in re_pos_init.keys(): for key in re_pos_init.keys():
if 'spectrum' in key: if "spectrum" in key:
re_pos_init[key] = re_pos_init[key].transpose() re_pos_init[key] = re_pos_init[key].transpose()
diffuse_field = diffuse(pos_init).val.reshape((4,)+sky_shape) diffuse_field = diffuse(pos_init).val.reshape((4,) + sky_shape)
stokes_field = stokes_sky(re_pos_init) stokes_field = stokes_sky(re_pos_init)
np.testing.assert_allclose(diffuse_field, stokes) np.testing.assert_allclose(diffuse_field, stokes_field)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment