diff --git a/Dockerfile b/Dockerfile index fffd63f2fd5cd64311fa94c6ce7267ddd2e774a0..36aca4e19612dc7871639938831500ed1dd3d83c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ RUN apt-get install -qq python3-mpi4py RUN apt-get install -qq python3-pytest-cov # Documentation dependencies RUN pip3 install pydata-sphinx-theme +RUN pip3 install jax jaxlib # Create user (openmpi does not like to be run as root) RUN useradd -ms /bin/bash testinguser diff --git a/README.md b/README.md index d1e17157192cea4ba7fb42e35da2933f281e25c1..df7f78ac8047d3d7033c5bd63261264deda54af0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ In the likely case that you encounter bugs, please contact me via [email](mailto - Optional dependencies are: - For reading [measurement sets](https://casa.nrao.edu/Memos/229.html), install [python-casacore](https://github.com/casacore/python-casacore). - For reading and writing FITS files: astropy. + - Some operators support [jax](https://github.com/google/jax). ## Related publications diff --git a/demo/basic_polarization.py b/demo/basic_polarization.py index 64d1ab3eed85ac2fd1e2e5fcfe052b0ad47ff3ca..8b0fe06d56754aa4625f1345a2a3c9de1c5a9f36 100644 --- a/demo/basic_polarization.py +++ b/demo/basic_polarization.py @@ -117,7 +117,7 @@ def main(): prefix=f"log{kk}", ) logop = reduce(add, [vv.ducktape_left(kk) for kk, vv in opdct.items()]) - mexp = rve.polarization_matrix_exponential(logop.target, args.with_v) + mexp = rve.polarization_matrix_exponential(logop["i"].target, args.with_v) # ift.extra.check_operator(mexp, ift.from_random(mexp.domain)*0.1, ntries=5) sky = mexp @ logop duckI = ift.ducktape(None, sky.target, "I") diff --git a/resolve/polarization_matrix_exponential.py b/resolve/polarization_matrix_exponential.py index cc5b4abc18260cc77f1f959756a9c8b29d137e5c..fd591dad89edeabaff97dc596f2e6bf2dccd70c6 100644 --- a/resolve/polarization_matrix_exponential.py +++ b/resolve/polarization_matrix_exponential.py @@ -1,32 +1,35 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Copyright(C) 2019-2020 Max-Planck-Society +# Copyright(C) 2019-2021 Max-Planck-Society # Author: Philipp Arras import nifty8 as ift -class polarization_matrix_exponential(ift.Operator): - def __init__(self, domain, with_v): +def polarization_matrix_exponential(domain, with_v, jax=False): + dom = ift.makeDomain(domain) + keys = ["i", "q", "u"] + if with_v: + keys += ["v"] + domain = ift.makeDomain({kk: dom for kk in keys}) + target = ift.makeDomain({kk.upper(): dom for kk in keys}) + if jax: + return _jax_pol(domain, target) + return PolarizationMatrixExponential(domain, target) + + +class PolarizationMatrixExponential(ift.Operator): + def __init__(self, domain, target): self._domain = ift.makeDomain(domain) - keys = ["i", "q", "u"] - if with_v: - keys += ["v"] - assert set(self._domain.keys()) == set(keys) - assert self._domain["i"] == self._domain["q"] == self._domain["u"] - if with_v: - assert self._domain["i"] == self._domain["v"] - self._target = ift.makeDomain( - {kk.upper(): self._domain["i"] for kk in self._domain.keys()} - ) - self._with_v = with_v + self._target = ift.makeDomain(target) def apply(self, x): self._check_input(x) + with_v = "v" in self.domain.keys() duckI = ift.ducktape(None, self._domain["i"], "I") duckQ = ift.ducktape(None, self._domain["q"], "Q") duckU = ift.ducktape(None, self._domain["u"], "U") tmpi = x["i"].exp() - if self._with_v: + if with_v: duckV = ift.ducktape(None, self._domain["u"], "V") log_p = (x["q"] ** 2 + x["u"] ** 2 + x["v"] ** 2).sqrt() else: @@ -35,15 +38,37 @@ class polarization_matrix_exponential(ift.Operator): tmp = tmpi * log_p.sinh() * log_p.reciprocal() U = duckU(tmp * x["u"]) Q = duckQ(tmp * x["q"]) - if self._with_v: + if with_v: V = duckV(tmp * x["v"]) if ift.is_linearization(x): val = I.val.unite(U.val.unite(Q.val)) jac = I.jac + U.jac + Q.jac - if self._with_v: + if with_v: val = val.unite(V.val) jac = jac + V.jac return x.new(val, jac) - if self._with_v: + if with_v: return I.unite(U.unite(Q.unite(V))) return I.unite(U.unite(Q)) + + +def _jax_pol(domain, target): + from jax.numpy import sqrt, exp, cosh, sinh + with_v = "v" in domain.keys() + + def func(x): + res = {} + sq = x["q"] ** 2 + x["u"] ** 2 + if with_v: + sq = sq + x["v"] ** 2 + log_p = sqrt(sq) + tmpi = exp(x["i"]) + res["I"] = tmpi * cosh(log_p) + tmp = tmpi * sinh(log_p) / log_p + res["U"] = tmp * x["u"] + res["Q"] = tmp * x["q"] + if with_v: + res["V"] = tmp * x["v"] + return res + + return ift.JaxOperator(domain, target, func) diff --git a/test/test_polarization.py b/test/test_polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd552aeba3d2944650c7857ac25b1b5a0597b17 --- /dev/null +++ b/test/test_polarization.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# Copyright(C) 2021 Max-Planck-Society +# Author: Philipp Arras + +from os.path import join + +import numpy as np +import pytest + +import nifty8 as ift +import resolve as rve + +pmp = pytest.mark.parametrize + + +@pmp("with_v", (False, True)) +def test_polarization(with_v): + dom = ift.RGSpace([10, 20]) + op = rve.polarization_matrix_exponential(dom, with_v, False) + op_jax = rve.polarization_matrix_exponential(dom, with_v, True) + + assert op.domain is op_jax.domain + assert op.target is op_jax.target + pos = ift.from_random(op.domain) + ift.extra.assert_allclose(op(pos), op_jax(pos)) + + ift.extra.check_operator(op, pos, ntries=5) + ift.extra.check_operator(op, pos, ntries=5)