diff --git a/Dockerfile b/Dockerfile
index fffd63f2fd5cd64311fa94c6ce7267ddd2e774a0..36aca4e19612dc7871639938831500ed1dd3d83c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,6 +15,7 @@ RUN apt-get install -qq python3-mpi4py
 RUN apt-get install -qq python3-pytest-cov
 # Documentation dependencies
 RUN pip3 install pydata-sphinx-theme
+RUN pip3 install jax jaxlib
 
 # Create user (openmpi does not like to be run as root)
 RUN useradd -ms /bin/bash testinguser
diff --git a/README.md b/README.md
index d1e17157192cea4ba7fb42e35da2933f281e25c1..df7f78ac8047d3d7033c5bd63261264deda54af0 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@ In the likely case that you encounter bugs, please contact me via [email](mailto
 - Optional dependencies are:
     - For reading [measurement sets](https://casa.nrao.edu/Memos/229.html), install [python-casacore](https://github.com/casacore/python-casacore).
     - For reading and writing FITS files: astropy.
+    - Some operators support [jax](https://github.com/google/jax).
 
 ## Related publications
 
diff --git a/demo/basic_polarization.py b/demo/basic_polarization.py
index 64d1ab3eed85ac2fd1e2e5fcfe052b0ad47ff3ca..8b0fe06d56754aa4625f1345a2a3c9de1c5a9f36 100644
--- a/demo/basic_polarization.py
+++ b/demo/basic_polarization.py
@@ -117,7 +117,7 @@ def main():
                 prefix=f"log{kk}",
             )
         logop = reduce(add, [vv.ducktape_left(kk) for kk, vv in opdct.items()])
-        mexp = rve.polarization_matrix_exponential(logop.target, args.with_v)
+        mexp = rve.polarization_matrix_exponential(logop["i"].target, args.with_v)
         # ift.extra.check_operator(mexp, ift.from_random(mexp.domain)*0.1, ntries=5)
         sky = mexp @ logop
         duckI = ift.ducktape(None, sky.target, "I")
diff --git a/resolve/polarization_matrix_exponential.py b/resolve/polarization_matrix_exponential.py
index cc5b4abc18260cc77f1f959756a9c8b29d137e5c..fd591dad89edeabaff97dc596f2e6bf2dccd70c6 100644
--- a/resolve/polarization_matrix_exponential.py
+++ b/resolve/polarization_matrix_exponential.py
@@ -1,32 +1,35 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
-# Copyright(C) 2019-2020 Max-Planck-Society
+# Copyright(C) 2019-2021 Max-Planck-Society
 # Author: Philipp Arras
 
 import nifty8 as ift
 
 
-class polarization_matrix_exponential(ift.Operator):
-    def __init__(self, domain, with_v):
+def polarization_matrix_exponential(domain, with_v, jax=False):
+    dom = ift.makeDomain(domain)
+    keys = ["i", "q", "u"]
+    if with_v:
+        keys += ["v"]
+    domain = ift.makeDomain({kk: dom for kk in keys})
+    target = ift.makeDomain({kk.upper(): dom for kk in keys})
+    if jax:
+        return _jax_pol(domain, target)
+    return PolarizationMatrixExponential(domain, target)
+
+
+class PolarizationMatrixExponential(ift.Operator):
+    def __init__(self, domain, target):
         self._domain = ift.makeDomain(domain)
-        keys = ["i", "q", "u"]
-        if with_v:
-            keys += ["v"]
-        assert set(self._domain.keys()) == set(keys)
-        assert self._domain["i"] == self._domain["q"] == self._domain["u"]
-        if with_v:
-            assert self._domain["i"] == self._domain["v"]
-        self._target = ift.makeDomain(
-            {kk.upper(): self._domain["i"] for kk in self._domain.keys()}
-        )
-        self._with_v = with_v
+        self._target = ift.makeDomain(target)
 
     def apply(self, x):
         self._check_input(x)
+        with_v = "v" in self.domain.keys()
         duckI = ift.ducktape(None, self._domain["i"], "I")
         duckQ = ift.ducktape(None, self._domain["q"], "Q")
         duckU = ift.ducktape(None, self._domain["u"], "U")
         tmpi = x["i"].exp()
-        if self._with_v:
+        if with_v:
             duckV = ift.ducktape(None, self._domain["u"], "V")
             log_p = (x["q"] ** 2 + x["u"] ** 2 + x["v"] ** 2).sqrt()
         else:
@@ -35,15 +38,37 @@ class polarization_matrix_exponential(ift.Operator):
         tmp = tmpi * log_p.sinh() * log_p.reciprocal()
         U = duckU(tmp * x["u"])
         Q = duckQ(tmp * x["q"])
-        if self._with_v:
+        if with_v:
             V = duckV(tmp * x["v"])
         if ift.is_linearization(x):
             val = I.val.unite(U.val.unite(Q.val))
             jac = I.jac + U.jac + Q.jac
-            if self._with_v:
+            if with_v:
                 val = val.unite(V.val)
                 jac = jac + V.jac
             return x.new(val, jac)
-        if self._with_v:
+        if with_v:
             return I.unite(U.unite(Q.unite(V)))
         return I.unite(U.unite(Q))
+
+
+def _jax_pol(domain, target):
+    from jax.numpy import sqrt, exp, cosh, sinh
+    with_v = "v" in domain.keys()
+
+    def func(x):
+        res = {}
+        sq = x["q"] ** 2 + x["u"] ** 2
+        if with_v:
+            sq = sq + x["v"] ** 2
+        log_p = sqrt(sq)
+        tmpi = exp(x["i"])
+        res["I"] = tmpi * cosh(log_p)
+        tmp = tmpi * sinh(log_p) / log_p
+        res["U"] = tmp * x["u"]
+        res["Q"] = tmp * x["q"]
+        if with_v:
+            res["V"] = tmp * x["v"]
+        return res
+
+    return ift.JaxOperator(domain, target, func)
diff --git a/test/test_polarization.py b/test/test_polarization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fd552aeba3d2944650c7857ac25b1b5a0597b17
--- /dev/null
+++ b/test/test_polarization.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+# Copyright(C) 2021 Max-Planck-Society
+# Author: Philipp Arras
+
+from os.path import join
+
+import numpy as np
+import pytest
+
+import nifty8 as ift
+import resolve as rve
+
+pmp = pytest.mark.parametrize
+
+
+@pmp("with_v", (False, True))
+def test_polarization(with_v):
+    dom = ift.RGSpace([10, 20])
+    op = rve.polarization_matrix_exponential(dom, with_v, False)
+    op_jax = rve.polarization_matrix_exponential(dom, with_v, True)
+
+    assert op.domain is op_jax.domain
+    assert op.target is op_jax.target
+    pos = ift.from_random(op.domain)
+    ift.extra.assert_allclose(op(pos), op_jax(pos))
+
+    ift.extra.check_operator(op, pos, ntries=5)
+    ift.extra.check_operator(op, pos, ntries=5)