diff --git a/demos/0_intro.py b/demos/0_intro.py
index 4c61e1ebab94db5e629cb8fbfa7212517815aa04..0ac93f5300d92ff52ddc6982abe2dd73924e5c41 100644
--- a/demos/0_intro.py
+++ b/demos/0_intro.py
@@ -11,11 +11,10 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 seed = 42
diff --git a/demos/1_tomography.py b/demos/1_tomography.py
index dc27d917407b7047f94b16ecf72c8f2e44af5db3..90244368329eaa48d2eef5aabf18f6c6153273a0 100644
--- a/demos/1_tomography.py
+++ b/demos/1_tomography.py
@@ -9,12 +9,11 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 import numpy as np
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 # %%
diff --git a/demos/a_icr.py b/demos/a_icr.py
index 4791a6f89ede5638a462494acb3202ee76e201ec..526b690f4395ae01154122ae72f3f9d5b478e47e 100644
--- a/demos/a_icr.py
+++ b/demos/a_icr.py
@@ -2,7 +2,6 @@
 
 import jax
 import matplotlib.pyplot as plt
-
 import nifty8.re as jft
 from jax import random
 
diff --git a/demos/a_nonlinear_regression.py b/demos/a_nonlinear_regression.py
index 446880cd8e2cc4e089cdf7e4686c9632a4ff1426..d019e12e5c5f2f2fda34bb53764fa37ff7758066 100644
--- a/demos/a_nonlinear_regression.py
+++ b/demos/a_nonlinear_regression.py
@@ -7,15 +7,15 @@
 # # Demonstration of a non-linear regression using NIFTy.re
 
 # %%
+import operator as op
+from functools import partial
+
 import jax
 import matplotlib.pyplot as plt
-from jax import numpy as jnp
+import nifty8.re as jft
 import numpy as np
+from jax import numpy as jnp
 from jax import random
-from functools import partial
-import operator as op
-
-import nifty8.re as jft
 
 jax.config.update("jax_enable_x64", True)
 
diff --git a/demos/a_wiener_filter.py b/demos/a_wiener_filter.py
index e44d494238e8cf8ddc55b03c8c5f9ef2a519e7c0..9513a95398b23027f0887a756154e2ce64b18f88 100644
--- a/demos/a_wiener_filter.py
+++ b/demos/a_wiener_filter.py
@@ -11,11 +11,10 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 seed = 42
diff --git a/test/test_re/__init__.py b/test/test_re/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/test/test_re/test_correlated_field.py b/test/test_re/test_correlated_field.py
index 91d853eff3421304e068665cc3fe9fc1e6ab6b5b..0c1004bb7b36a3acca88c408ce45a8049dcd1d06 100644
--- a/test/test_re/test_correlated_field.py
+++ b/test/test_re/test_correlated_field.py
@@ -1,16 +1,14 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import jax.random as random
+import nifty8 as ift
+import nifty8.re as jft
 import pytest
 from numpy.testing import assert_allclose
 
-import nifty8 as ift
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_custom_map.py b/test/test_re/test_custom_map.py
index 0bbbcbac596605c2848d865da5c9edd04c0bd772..d9cecc2399be2f5ab076a0e22d7960d275d60414 100644
--- a/test/test_re/test_custom_map.py
+++ b/test/test_re/test_custom_map.py
@@ -1,23 +1,19 @@
 #!/usr/bin/env python3
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
 import jax.numpy as jnp
+import nifty8.re as jft
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
-jax.config.update("jax_enable_x64", True)
-
 
 def f(u, v):
     return jnp.exp(u @ v)
diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py
index 8d80c7c8f2a6e9c18facc4b20328c3bb9c248df7..072f1992a53a19b3f8b7636d3715ca9fe0048b9b 100644
--- a/test/test_re/test_estimate_evidence_lower_bound.py
+++ b/test/test_re/test_estimate_evidence_lower_bound.py
@@ -3,17 +3,16 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 # Author: Matteo Guardiani
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
 import jax.random as random
 import numpy as np
+import pytest
 
 import nifty8 as ift
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_evi.py b/test/test_re/test_evi.py
index 88dab250bdbfc9ab7dc485dc97804f6168cd271b..b4f990342030792e213cb8c72a08b5f3943db539 100644
--- a/test/test_re/test_evi.py
+++ b/test/test_re/test_evi.py
@@ -1,21 +1,18 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
 import jax.numpy as jnp
+import nifty8.re as jft
 import numpy as np
+import pytest
 from jax import random
 from numpy.testing import assert_allclose, assert_array_equal
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_forest_math.py b/test/test_re/test_forest_math.py
index 6ec08007f4bae0ca7f79381bf71a3b9784939344..308d0acbdeed9f3f431c214b86917c57b0df2886 100644
--- a/test/test_re/test_forest_math.py
+++ b/test/test_re/test_forest_math.py
@@ -1,8 +1,8 @@
+import jax
+import nifty8.re as jft
 import pytest
 
-pytest.importorskip("jax")
-
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 
 def test_map_forest_axes_validation():
diff --git a/test/test_re/test_gauss_markov.py b/test/test_re/test_gauss_markov.py
index 497c5f5533c5b601748ca9f87874740b01105aac..117ca1e2e943a29b85089fb299cda7c8e3ad02ab 100644
--- a/test/test_re/test_gauss_markov.py
+++ b/test/test_re/test_gauss_markov.py
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
+import pytest
 from jax import random, vmap
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py
index 8ae98817f71878ef6ace39faca82e721eeabf046..8d159d5342e8c5c79cd0b118a086ce070a90fd50 100644
--- a/test/test_re/test_hmc_1d_distributions.py
+++ b/test/test_re/test_hmc_1d_distributions.py
@@ -1,16 +1,15 @@
 import sys
 
+import jax
+import nifty8.re as jft
 import pytest
-
-pytest.importorskip("jax")
-
+import scipy
 from jax import numpy as jnp
 from jax.scipy import stats
 from numpy.testing import assert_allclose
-import scipy
 from scipy.special import comb
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py
index d114a1838768874c26d93cc9be89a14ffdda5408..c906b9784815a0cbd76cdf59c5f25f33c8c07d01 100644
--- a/test/test_re/test_hmc_hashes.py
+++ b/test/test_re/test_hmc_hashes.py
@@ -1,14 +1,12 @@
 import sys
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from numpy import ndarray
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 NDARRAY_TYPE = [ndarray]
 
diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py
index 7a14826bbc795eef7c9b74561d333eea7353f54f..f763efe0f977f37f79efc0ece1b336af005a0867 100644
--- a/test/test_re/test_hmc_leapfrog.py
+++ b/test/test_re/test_hmc_leapfrog.py
@@ -1,14 +1,12 @@
-import pytest
-
-pytest.importorskip("jax")
-
 import sys
 
-from jax import grad
+import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from numpy.testing import assert_allclose
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
@@ -29,7 +27,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False)
     mass_matrix = jnp.ones(shape=dims)
     kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.0)
 
-    potential_energy_gradient = grad(potential_energy)
+    potential_energy_gradient = jax.grad(potential_energy)
     positions = [jnp.array([-1.5, -1.55])]
     momenta = [jnp.array([-1, 1])]
     for _ in range(25):
@@ -52,7 +50,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False)
     ):
         msg = (
             f"q: {q}; p: {p}"
-            f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}"
+            f"\nE_tot: {e_pot + e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}"
         )
         print(msg, file=sys.stderr)
 
diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py
index 46ad659b48d49c51fe315125dc44cc6bbfe59daa..0c26f1479e31faca7f781d15a09ffb8bec1e4540 100644
--- a/test/test_re/test_hmc_pytree.py
+++ b/test/test_re/test_hmc_pytree.py
@@ -1,14 +1,13 @@
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
+import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from jax.tree_util import tree_leaves
 from numpy.testing import assert_array_equal
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 
 def test_hmc_pytree():
diff --git a/test/test_re/test_indexing.py b/test/test_re/test_indexing.py
index 6afb4306cab204a29230d7067e01d128e4910ee3..ab5ba65e2f009631333fb691eb21ea88793708d5 100644
--- a/test/test_re/test_indexing.py
+++ b/test/test_re/test_indexing.py
@@ -2,16 +2,16 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import numpy as np
+import pytest
 from jax.tree_util import tree_flatten, tree_map
 
 from nifty8.re.multi_grid.grid import FlatGrid, Grid, MGrid, OpenGrid
 from nifty8.re.multi_grid.grid_impl import HEALPixGrid
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py
index ff397e0e67ca4d93fefdd9fda74a6115ca135cce..179a81f06d13db5dcb1c67ff5bcea834e9fb0776 100644
--- a/test/test_re/test_lanczos.py
+++ b/test/test_re/test_lanczos.py
@@ -2,20 +2,21 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
-from functools import partial
 import sys
-from jax import random
+from functools import partial
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
+from jax import random
 from numpy.testing import assert_allclose
 from scipy.spatial import distance_matrix
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py
index 3810681082f2395b7bbb07997b15e50834adb665..a15bccf91e728f35cf998b1dd31c0ab137af0779 100644
--- a/test/test_re/test_likelihood.py
+++ b/test/test_re/test_likelihood.py
@@ -2,13 +2,10 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
+import pytest
 from jax import numpy as jnp
 from jax import random
 from jax.tree_util import tree_map
@@ -17,6 +14,8 @@ from numpy.testing import assert_allclose
 import nifty8.re as jft
 from nifty8.re.likelihood import partial_insert_and_remove as jpartial
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 jax.config.update("jax_enable_x64", True)
diff --git a/test/test_re/test_likelihood_impl.py b/test/test_re/test_likelihood_impl.py
index f423a1facfe9aa3a33e4679e18570ea7df3bb954..c251e4679265143bddef97fc0164cd35cc5361c6 100644
--- a/test/test_re/test_likelihood_impl.py
+++ b/test/test_re/test_likelihood_impl.py
@@ -1,19 +1,18 @@
 #!/usr/bin/env python3
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial, reduce
 
 import jax
 import jax.numpy as jnp
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_minisanity.py b/test/test_re/test_minisanity.py
index 0e8d3072bd7949db2a61250773d497bc4c09231c..4a5496d1aae34c5aa39656a2c2db81fd538829c8 100644
--- a/test/test_re/test_minisanity.py
+++ b/test/test_re/test_minisanity.py
@@ -2,15 +2,15 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import numpy as np
+import pytest
 from numpy.testing import assert_allclose, assert_array_equal
 
 from nifty8.re.minisanity import reduced_residual_stats
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_misc.py b/test/test_re/test_misc.py
index 96b6053c9f3f2142a356d80c16a253abb33c6833..fec927853907aad975371f8f11cd12860c21daa1 100644
--- a/test/test_re/test_misc.py
+++ b/test/test_re/test_misc.py
@@ -1,9 +1,10 @@
+import jax
+import numpy as np
 import pytest
 
-pytest.importorskip("jax")
-
 import nifty8.re as jft
-import numpy as np
+
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py
index 2adab31d6cb053a0d9e7887433623315a689bfa9..fc17651a1febb92f7f421c1e540d88955059c645 100644
--- a/test/test_re/test_ncg.py
+++ b/test/test_re/test_ncg.py
@@ -1,19 +1,18 @@
 #!/usr/bin/env python3
 
 import sys
-
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
+import jax
 import jax.numpy as jnp
+import pytest
 from jax import jit, random, value_and_grad
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_num.py b/test/test_re/test_num.py
index 2243fa255809395e51015cf5fa4cb5b0ce010017..0216a9030a0647f73e878433325113f4a95bc627 100644
--- a/test/test_re/test_num.py
+++ b/test/test_re/test_num.py
@@ -2,16 +2,16 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
+import jax
+import numpy as np
 import pytest
-
-pytest.importorskip("jax")
-
 from jax import numpy as jnp
-import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
 
 from nifty8.re.num import amend_unique, amend_unique_, unique
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py
index dc515c92871c06867f2440ef7497310f05a22831..1e5f8a2881b698d40acfd793c9343daa36253d53 100644
--- a/test/test_re/test_optimize_kl.py
+++ b/test/test_re/test_optimize_kl.py
@@ -1,15 +1,12 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial, reduce
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose, assert_array_equal
@@ -17,6 +14,8 @@ from numpy.testing import assert_allclose, assert_array_equal
 import nifty8.re as jft
 from nifty8.re.optimize_kl import concatenate_zip
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 jax.config.update("jax_enable_x64", True)
diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py
index 3b0a28c27599dfd992feb818931c880a8ae6528e..8af6d237a7bc5cb878b1581360463d58020a8c8b 100644
--- a/test/test_re/test_refine.py
+++ b/test/test_re/test_refine.py
@@ -2,18 +2,15 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-from functools import partial
 import sys
-
-import pytest
-
-pytest.importorskip("jax")
+from functools import partial
 
 import jax
-from jax import random
 import jax.numpy as jnp
-from jax.tree_util import Partial
 import numpy as np
+import pytest
+from jax import random
+from jax.tree_util import Partial
 from numpy.testing import assert_allclose
 from scipy.spatial import distance_matrix
 
@@ -26,6 +23,8 @@ try:
 except ImportError:
     healpy = None
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_refine_healpix.py b/test/test_re/test_refine_healpix.py
index 42e36fa65f145a3eeeaf2ad140581e924089ab99..d3adb2d3792ebed36614de6ba93e173882dfbad0 100644
--- a/test/test_re/test_refine_healpix.py
+++ b/test/test_re/test_refine_healpix.py
@@ -5,24 +5,20 @@
 import sys
 from functools import partial
 
-import pytest
-
-pytest.importorskip("jax")
-pytest.importorskip("healpy")
-
 import jax
-from jax import random
 import jax.numpy as jnp
-from jax.tree_util import Partial
 import numpy as np
+import pytest
+from jax import random
+from jax.tree_util import Partial
 from numpy.testing import assert_allclose, assert_array_equal
 
 import nifty8.re as jft
 
-pmp = pytest.mark.parametrize
-
 jax.config.update("jax_enable_x64", True)
 
+pmp = pytest.mark.parametrize
+
 
 def lst2fixt(lst):
     @pytest.fixture(params=lst)
diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py
index e97b1d1a4a76e9414b5b12613870f3874c969a7f..eba8420f265452a80f377f58e94100fbddd819fc 100644
--- a/test/test_re/test_refine_util.py
+++ b/test/test_re/test_refine_util.py
@@ -4,16 +4,15 @@
 
 from functools import partial
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
 import numpy as np
+import pytest
 
 import nifty8.re as jft
 from nifty8.re.refine import util
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py
index 00f71349f1c770750a08e900bbfdee96510d6863..abbc78cbaf1db41c2fd12df3dead80673f282c30 100644
--- a/test/test_re/test_stats_distributions.py
+++ b/test/test_re/test_stats_distributions.py
@@ -1,15 +1,13 @@
-import pytest
-
-pytest.importorskip("jax")
-
-from functools import partial
-
+import jax
 import numpy as np
+import pytest
 from numpy.testing import assert_allclose
 from scipy import stats
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_vmodel.py b/test/test_re/test_vmodel.py
index 465522d7e75aa31fcaab546455df94799ee38f5e..44f01fb41b8b5f9ac82d591fe092f7769f66c9e1 100644
--- a/test/test_re/test_vmodel.py
+++ b/test/test_re/test_vmodel.py
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import jax.numpy as jnp
 import jax.random as random
-from jax import vmap
+import pytest
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
@@ -21,7 +20,7 @@ def _v_compare(model, axis_size, in_axes=0, out_axes=0):
 
     x = jft.random_like(random.PRNGKey(10), vmodel.domain)
     res = vmodel(x)
-    gt = vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x)
+    gt = jax.vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x)
 
     tree_map(assert_allclose, gt, res)
 
@@ -73,7 +72,7 @@ def test_key(key):
 
     x = jft.random_like(random.PRNGKey(10), vmodel.domain)
     res = vmodel(x)
-    gt = vmap(model, in_axes=({"a": 0, "b": None},))(x)
+    gt = jax.vmap(model, in_axes=({"a": 0, "b": None},))(x)
     assert_allclose(gt, res)
 
     # Check intended dict handling with dummy input