From d12b16f514f4c11374fa1ec84f75f946ed8fe562 Mon Sep 17 00:00:00 2001
From: Gordian Edenhofer <gordian.edenhofer@gmail.com>
Date: Sun, 16 Feb 2025 12:05:19 -0600
Subject: [PATCH] Enable JAX x64 in tests and demos

---
 demos/0_intro.py                                   |  3 +--
 demos/1_tomography.py                              |  3 +--
 demos/a_icr.py                                     |  1 -
 demos/a_nonlinear_regression.py                    | 10 +++++-----
 demos/a_wiener_filter.py                           |  3 +--
 test/test_re/__init__.py                           |  0
 test/test_re/test_correlated_field.py              | 10 ++++------
 test/test_re/test_custom_map.py                    | 10 +++-------
 test/test_re/test_estimate_evidence_lower_bound.py |  7 +++----
 test/test_re/test_evi.py                           |  9 +++------
 test/test_re/test_forest_math.py                   |  6 +++---
 test/test_re/test_gauss_markov.py                  |  9 ++++-----
 test/test_re/test_hmc_1d_distributions.py          |  9 ++++-----
 test/test_re/test_hmc_hashes.py                    |  8 +++-----
 test/test_re/test_hmc_leapfrog.py                  | 14 ++++++--------
 test/test_re/test_hmc_pytree.py                    |  9 ++++-----
 test/test_re/test_indexing.py                      |  8 ++++----
 test/test_re/test_lanczos.py                       | 13 +++++++------
 test/test_re/test_likelihood.py                    |  7 +++----
 test/test_re/test_likelihood_impl.py               |  7 +++----
 test/test_re/test_minisanity.py                    |  8 ++++----
 test/test_re/test_misc.py                          |  7 ++++---
 test/test_re/test_ncg.py                           |  9 ++++-----
 test/test_re/test_num.py                           |  8 ++++----
 test/test_re/test_optimize_kl.py                   |  7 +++----
 test/test_re/test_refine.py                        | 13 ++++++-------
 test/test_re/test_refine_healpix.py                | 14 +++++---------
 test/test_re/test_refine_util.py                   |  7 +++----
 test/test_re/test_stats_distributions.py           | 10 ++++------
 test/test_re/test_vmodel.py                        | 13 ++++++-------
 30 files changed, 105 insertions(+), 137 deletions(-)
 delete mode 100644 test/test_re/__init__.py

diff --git a/demos/0_intro.py b/demos/0_intro.py
index 4c61e1eba..0ac93f530 100644
--- a/demos/0_intro.py
+++ b/demos/0_intro.py
@@ -11,11 +11,10 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 seed = 42
diff --git a/demos/1_tomography.py b/demos/1_tomography.py
index dc27d9174..902443683 100644
--- a/demos/1_tomography.py
+++ b/demos/1_tomography.py
@@ -9,12 +9,11 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 import numpy as np
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 # %%
diff --git a/demos/a_icr.py b/demos/a_icr.py
index 4791a6f89..526b690f4 100644
--- a/demos/a_icr.py
+++ b/demos/a_icr.py
@@ -2,7 +2,6 @@
 
 import jax
 import matplotlib.pyplot as plt
-
 import nifty8.re as jft
 from jax import random
 
diff --git a/demos/a_nonlinear_regression.py b/demos/a_nonlinear_regression.py
index 446880cd8..d019e12e5 100644
--- a/demos/a_nonlinear_regression.py
+++ b/demos/a_nonlinear_regression.py
@@ -7,15 +7,15 @@
 # # Demonstration of a non-linear regression using NIFTy.re
 
 # %%
+import operator as op
+from functools import partial
+
 import jax
 import matplotlib.pyplot as plt
-from jax import numpy as jnp
+import nifty8.re as jft
 import numpy as np
+from jax import numpy as jnp
 from jax import random
-from functools import partial
-import operator as op
-
-import nifty8.re as jft
 
 jax.config.update("jax_enable_x64", True)
 
diff --git a/demos/a_wiener_filter.py b/demos/a_wiener_filter.py
index e44d49423..9513a9539 100644
--- a/demos/a_wiener_filter.py
+++ b/demos/a_wiener_filter.py
@@ -11,11 +11,10 @@
 # %%
 import jax
 import matplotlib.pyplot as plt
+import nifty8.re as jft
 from jax import numpy as jnp
 from jax import random
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
 
 seed = 42
diff --git a/test/test_re/__init__.py b/test/test_re/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/test/test_re/test_correlated_field.py b/test/test_re/test_correlated_field.py
index 91d853eff..0c1004bb7 100644
--- a/test/test_re/test_correlated_field.py
+++ b/test/test_re/test_correlated_field.py
@@ -1,16 +1,14 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import jax.random as random
+import nifty8 as ift
+import nifty8.re as jft
 import pytest
 from numpy.testing import assert_allclose
 
-import nifty8 as ift
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_custom_map.py b/test/test_re/test_custom_map.py
index 0bbbcbac5..d9cecc239 100644
--- a/test/test_re/test_custom_map.py
+++ b/test/test_re/test_custom_map.py
@@ -1,23 +1,19 @@
 #!/usr/bin/env python3
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
 import jax.numpy as jnp
+import nifty8.re as jft
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
-jax.config.update("jax_enable_x64", True)
-
 
 def f(u, v):
     return jnp.exp(u @ v)
diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py
index 8d80c7c8f..072f1992a 100644
--- a/test/test_re/test_estimate_evidence_lower_bound.py
+++ b/test/test_re/test_estimate_evidence_lower_bound.py
@@ -3,17 +3,16 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 # Author: Matteo Guardiani
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
 import jax.random as random
 import numpy as np
+import pytest
 
 import nifty8 as ift
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_evi.py b/test/test_re/test_evi.py
index 88dab250b..b4f990342 100644
--- a/test/test_re/test_evi.py
+++ b/test/test_re/test_evi.py
@@ -1,21 +1,18 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
 import jax.numpy as jnp
+import nifty8.re as jft
 import numpy as np
+import pytest
 from jax import random
 from numpy.testing import assert_allclose, assert_array_equal
 
-import nifty8.re as jft
-
 jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_forest_math.py b/test/test_re/test_forest_math.py
index 6ec08007f..308d0acbd 100644
--- a/test/test_re/test_forest_math.py
+++ b/test/test_re/test_forest_math.py
@@ -1,8 +1,8 @@
+import jax
+import nifty8.re as jft
 import pytest
 
-pytest.importorskip("jax")
-
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 
 def test_map_forest_axes_validation():
diff --git a/test/test_re/test_gauss_markov.py b/test/test_re/test_gauss_markov.py
index 497c5f553..117ca1e2e 100644
--- a/test/test_re/test_gauss_markov.py
+++ b/test/test_re/test_gauss_markov.py
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
+import pytest
 from jax import random, vmap
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py
index 8ae98817f..8d159d534 100644
--- a/test/test_re/test_hmc_1d_distributions.py
+++ b/test/test_re/test_hmc_1d_distributions.py
@@ -1,16 +1,15 @@
 import sys
 
+import jax
+import nifty8.re as jft
 import pytest
-
-pytest.importorskip("jax")
-
+import scipy
 from jax import numpy as jnp
 from jax.scipy import stats
 from numpy.testing import assert_allclose
-import scipy
 from scipy.special import comb
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py
index d114a1838..c906b9784 100644
--- a/test/test_re/test_hmc_hashes.py
+++ b/test/test_re/test_hmc_hashes.py
@@ -1,14 +1,12 @@
 import sys
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from numpy import ndarray
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 NDARRAY_TYPE = [ndarray]
 
diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py
index 7a14826bb..f763efe0f 100644
--- a/test/test_re/test_hmc_leapfrog.py
+++ b/test/test_re/test_hmc_leapfrog.py
@@ -1,14 +1,12 @@
-import pytest
-
-pytest.importorskip("jax")
-
 import sys
 
-from jax import grad
+import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from numpy.testing import assert_allclose
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
@@ -29,7 +27,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False)
     mass_matrix = jnp.ones(shape=dims)
     kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.0)
 
-    potential_energy_gradient = grad(potential_energy)
+    potential_energy_gradient = jax.grad(potential_energy)
     positions = [jnp.array([-1.5, -1.55])]
     momenta = [jnp.array([-1, 1])]
     for _ in range(25):
@@ -52,7 +50,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False)
     ):
         msg = (
             f"q: {q}; p: {p}"
-            f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}"
+            f"\nE_tot: {e_pot + e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}"
         )
         print(msg, file=sys.stderr)
 
diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py
index 46ad659b4..0c26f1479 100644
--- a/test/test_re/test_hmc_pytree.py
+++ b/test/test_re/test_hmc_pytree.py
@@ -1,14 +1,13 @@
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
+import jax
+import nifty8.re as jft
+import pytest
 from jax import numpy as jnp
 from jax.tree_util import tree_leaves
 from numpy.testing import assert_array_equal
 
-import nifty8.re as jft
+jax.config.update("jax_enable_x64", True)
 
 
 def test_hmc_pytree():
diff --git a/test/test_re/test_indexing.py b/test/test_re/test_indexing.py
index 6afb4306c..ab5ba65e2 100644
--- a/test/test_re/test_indexing.py
+++ b/test/test_re/test_indexing.py
@@ -2,16 +2,16 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import numpy as np
+import pytest
 from jax.tree_util import tree_flatten, tree_map
 
 from nifty8.re.multi_grid.grid import FlatGrid, Grid, MGrid, OpenGrid
 from nifty8.re.multi_grid.grid_impl import HEALPixGrid
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py
index ff397e0e6..179a81f06 100644
--- a/test/test_re/test_lanczos.py
+++ b/test/test_re/test_lanczos.py
@@ -2,20 +2,21 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
-from functools import partial
 import sys
-from jax import random
+from functools import partial
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
+from jax import random
 from numpy.testing import assert_allclose
 from scipy.spatial import distance_matrix
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py
index 381068108..a15bccf91 100644
--- a/test/test_re/test_likelihood.py
+++ b/test/test_re/test_likelihood.py
@@ -2,13 +2,10 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
 import jax
+import pytest
 from jax import numpy as jnp
 from jax import random
 from jax.tree_util import tree_map
@@ -17,6 +14,8 @@ from numpy.testing import assert_allclose
 import nifty8.re as jft
 from nifty8.re.likelihood import partial_insert_and_remove as jpartial
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 jax.config.update("jax_enable_x64", True)
diff --git a/test/test_re/test_likelihood_impl.py b/test/test_re/test_likelihood_impl.py
index f423a1fac..c251e4679 100644
--- a/test/test_re/test_likelihood_impl.py
+++ b/test/test_re/test_likelihood_impl.py
@@ -1,19 +1,18 @@
 #!/usr/bin/env python3
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial, reduce
 
 import jax
 import jax.numpy as jnp
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_minisanity.py b/test/test_re/test_minisanity.py
index 0e8d3072b..4a5496d1a 100644
--- a/test/test_re/test_minisanity.py
+++ b/test/test_re/test_minisanity.py
@@ -2,15 +2,15 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import numpy as np
+import pytest
 from numpy.testing import assert_allclose, assert_array_equal
 
 from nifty8.re.minisanity import reduced_residual_stats
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_misc.py b/test/test_re/test_misc.py
index 96b6053c9..fec927853 100644
--- a/test/test_re/test_misc.py
+++ b/test/test_re/test_misc.py
@@ -1,9 +1,10 @@
+import jax
+import numpy as np
 import pytest
 
-pytest.importorskip("jax")
-
 import nifty8.re as jft
-import numpy as np
+
+jax.config.update("jax_enable_x64", True)
 
 pmp = pytest.mark.parametrize
 
diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py
index 2adab31d6..fc17651a1 100644
--- a/test/test_re/test_ncg.py
+++ b/test/test_re/test_ncg.py
@@ -1,19 +1,18 @@
 #!/usr/bin/env python3
 
 import sys
-
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial
 
+import jax
 import jax.numpy as jnp
+import pytest
 from jax import jit, random, value_and_grad
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_num.py b/test/test_re/test_num.py
index 2243fa255..0216a9030 100644
--- a/test/test_re/test_num.py
+++ b/test/test_re/test_num.py
@@ -2,16 +2,16 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
+import jax
+import numpy as np
 import pytest
-
-pytest.importorskip("jax")
-
 from jax import numpy as jnp
-import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
 
 from nifty8.re.num import amend_unique, amend_unique_, unique
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py
index dc515c928..1e5f8a288 100644
--- a/test/test_re/test_optimize_kl.py
+++ b/test/test_re/test_optimize_kl.py
@@ -1,15 +1,12 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
 from functools import partial, reduce
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
 from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose, assert_array_equal
@@ -17,6 +14,8 @@ from numpy.testing import assert_allclose, assert_array_equal
 import nifty8.re as jft
 from nifty8.re.optimize_kl import concatenate_zip
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 jax.config.update("jax_enable_x64", True)
diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py
index 3b0a28c27..8af6d237a 100644
--- a/test/test_re/test_refine.py
+++ b/test/test_re/test_refine.py
@@ -2,18 +2,15 @@
 
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-from functools import partial
 import sys
-
-import pytest
-
-pytest.importorskip("jax")
+from functools import partial
 
 import jax
-from jax import random
 import jax.numpy as jnp
-from jax.tree_util import Partial
 import numpy as np
+import pytest
+from jax import random
+from jax.tree_util import Partial
 from numpy.testing import assert_allclose
 from scipy.spatial import distance_matrix
 
@@ -26,6 +23,8 @@ try:
 except ImportError:
     healpy = None
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_refine_healpix.py b/test/test_re/test_refine_healpix.py
index 42e36fa65..d3adb2d37 100644
--- a/test/test_re/test_refine_healpix.py
+++ b/test/test_re/test_refine_healpix.py
@@ -5,24 +5,20 @@
 import sys
 from functools import partial
 
-import pytest
-
-pytest.importorskip("jax")
-pytest.importorskip("healpy")
-
 import jax
-from jax import random
 import jax.numpy as jnp
-from jax.tree_util import Partial
 import numpy as np
+import pytest
+from jax import random
+from jax.tree_util import Partial
 from numpy.testing import assert_allclose, assert_array_equal
 
 import nifty8.re as jft
 
-pmp = pytest.mark.parametrize
-
 jax.config.update("jax_enable_x64", True)
 
+pmp = pytest.mark.parametrize
+
 
 def lst2fixt(lst):
     @pytest.fixture(params=lst)
diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py
index e97b1d1a4..eba8420f2 100644
--- a/test/test_re/test_refine_util.py
+++ b/test/test_re/test_refine_util.py
@@ -4,16 +4,15 @@
 
 from functools import partial
 
-import pytest
-
-pytest.importorskip("jax")
-
 import jax
 import numpy as np
+import pytest
 
 import nifty8.re as jft
 from nifty8.re.refine import util
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py
index 00f71349f..abbc78cba 100644
--- a/test/test_re/test_stats_distributions.py
+++ b/test/test_re/test_stats_distributions.py
@@ -1,15 +1,13 @@
-import pytest
-
-pytest.importorskip("jax")
-
-from functools import partial
-
+import jax
 import numpy as np
+import pytest
 from numpy.testing import assert_allclose
 from scipy import stats
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
diff --git a/test/test_re/test_vmodel.py b/test/test_re/test_vmodel.py
index 465522d7e..44f01fb41 100644
--- a/test/test_re/test_vmodel.py
+++ b/test/test_re/test_vmodel.py
@@ -1,18 +1,17 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import pytest
-
-pytest.importorskip("jax")
-
+import jax
 import jax.numpy as jnp
 import jax.random as random
-from jax import vmap
+import pytest
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose
 
 import nifty8.re as jft
 
+jax.config.update("jax_enable_x64", True)
+
 pmp = pytest.mark.parametrize
 
 
@@ -21,7 +20,7 @@ def _v_compare(model, axis_size, in_axes=0, out_axes=0):
 
     x = jft.random_like(random.PRNGKey(10), vmodel.domain)
     res = vmodel(x)
-    gt = vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x)
+    gt = jax.vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x)
 
     tree_map(assert_allclose, gt, res)
 
@@ -73,7 +72,7 @@ def test_key(key):
 
     x = jft.random_like(random.PRNGKey(10), vmodel.domain)
     res = vmodel(x)
-    gt = vmap(model, in_axes=({"a": 0, "b": None},))(x)
+    gt = jax.vmap(model, in_axes=({"a": 0, "b": None},))(x)
     assert_allclose(gt, res)
 
     # Check intended dict handling with dummy input
-- 
GitLab