From d12b16f514f4c11374fa1ec84f75f946ed8fe562 Mon Sep 17 00:00:00 2001 From: Gordian Edenhofer <gordian.edenhofer@gmail.com> Date: Sun, 16 Feb 2025 12:05:19 -0600 Subject: [PATCH] Enable JAX x64 in tests and demos --- demos/0_intro.py | 3 +-- demos/1_tomography.py | 3 +-- demos/a_icr.py | 1 - demos/a_nonlinear_regression.py | 10 +++++----- demos/a_wiener_filter.py | 3 +-- test/test_re/__init__.py | 0 test/test_re/test_correlated_field.py | 10 ++++------ test/test_re/test_custom_map.py | 10 +++------- test/test_re/test_estimate_evidence_lower_bound.py | 7 +++---- test/test_re/test_evi.py | 9 +++------ test/test_re/test_forest_math.py | 6 +++--- test/test_re/test_gauss_markov.py | 9 ++++----- test/test_re/test_hmc_1d_distributions.py | 9 ++++----- test/test_re/test_hmc_hashes.py | 8 +++----- test/test_re/test_hmc_leapfrog.py | 14 ++++++-------- test/test_re/test_hmc_pytree.py | 9 ++++----- test/test_re/test_indexing.py | 8 ++++---- test/test_re/test_lanczos.py | 13 +++++++------ test/test_re/test_likelihood.py | 7 +++---- test/test_re/test_likelihood_impl.py | 7 +++---- test/test_re/test_minisanity.py | 8 ++++---- test/test_re/test_misc.py | 7 ++++--- test/test_re/test_ncg.py | 9 ++++----- test/test_re/test_num.py | 8 ++++---- test/test_re/test_optimize_kl.py | 7 +++---- test/test_re/test_refine.py | 13 ++++++------- test/test_re/test_refine_healpix.py | 14 +++++--------- test/test_re/test_refine_util.py | 7 +++---- test/test_re/test_stats_distributions.py | 10 ++++------ test/test_re/test_vmodel.py | 13 ++++++------- 30 files changed, 105 insertions(+), 137 deletions(-) delete mode 100644 test/test_re/__init__.py diff --git a/demos/0_intro.py b/demos/0_intro.py index 4c61e1eba..0ac93f530 100644 --- a/demos/0_intro.py +++ b/demos/0_intro.py @@ -11,11 +11,10 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) seed = 42 diff --git a/demos/1_tomography.py b/demos/1_tomography.py index dc27d9174..902443683 100644 --- a/demos/1_tomography.py +++ b/demos/1_tomography.py @@ -9,12 +9,11 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft import numpy as np from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) # %% diff --git a/demos/a_icr.py b/demos/a_icr.py index 4791a6f89..526b690f4 100644 --- a/demos/a_icr.py +++ b/demos/a_icr.py @@ -2,7 +2,6 @@ import jax import matplotlib.pyplot as plt - import nifty8.re as jft from jax import random diff --git a/demos/a_nonlinear_regression.py b/demos/a_nonlinear_regression.py index 446880cd8..d019e12e5 100644 --- a/demos/a_nonlinear_regression.py +++ b/demos/a_nonlinear_regression.py @@ -7,15 +7,15 @@ # # Demonstration of a non-linear regression using NIFTy.re # %% +import operator as op +from functools import partial + import jax import matplotlib.pyplot as plt -from jax import numpy as jnp +import nifty8.re as jft import numpy as np +from jax import numpy as jnp from jax import random -from functools import partial -import operator as op - -import nifty8.re as jft jax.config.update("jax_enable_x64", True) diff --git a/demos/a_wiener_filter.py b/demos/a_wiener_filter.py index e44d49423..9513a9539 100644 --- a/demos/a_wiener_filter.py +++ b/demos/a_wiener_filter.py @@ -11,11 +11,10 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) seed = 42 diff --git a/test/test_re/__init__.py b/test/test_re/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/test_re/test_correlated_field.py b/test/test_re/test_correlated_field.py index 91d853eff..0c1004bb7 100644 --- a/test/test_re/test_correlated_field.py +++ b/test/test_re/test_correlated_field.py @@ -1,16 +1,14 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import jax.random as random +import nifty8 as ift +import nifty8.re as jft import pytest from numpy.testing import assert_allclose -import nifty8 as ift -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_custom_map.py b/test/test_re/test_custom_map.py index 0bbbcbac5..d9cecc239 100644 --- a/test/test_re/test_custom_map.py +++ b/test/test_re/test_custom_map.py @@ -1,23 +1,19 @@ #!/usr/bin/env python3 -import pytest - -pytest.importorskip("jax") - from functools import partial import jax import jax.numpy as jnp +import nifty8.re as jft +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize -jax.config.update("jax_enable_x64", True) - def f(u, v): return jnp.exp(u @ v) diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py index 8d80c7c8f..072f1992a 100644 --- a/test/test_re/test_estimate_evidence_lower_bound.py +++ b/test/test_re/test_estimate_evidence_lower_bound.py @@ -3,17 +3,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause # Author: Matteo Guardiani -import pytest - -pytest.importorskip("jax") - import jax import jax.random as random import numpy as np +import pytest import nifty8 as ift import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_evi.py b/test/test_re/test_evi.py index 88dab250b..b4f990342 100644 --- a/test/test_re/test_evi.py +++ b/test/test_re/test_evi.py @@ -1,21 +1,18 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial import jax import jax.numpy as jnp +import nifty8.re as jft import numpy as np +import pytest from jax import random from numpy.testing import assert_allclose, assert_array_equal -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_forest_math.py b/test/test_re/test_forest_math.py index 6ec08007f..308d0acbd 100644 --- a/test/test_re/test_forest_math.py +++ b/test/test_re/test_forest_math.py @@ -1,8 +1,8 @@ +import jax +import nifty8.re as jft import pytest -pytest.importorskip("jax") - -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) def test_map_forest_axes_validation(): diff --git a/test/test_re/test_gauss_markov.py b/test/test_re/test_gauss_markov.py index 497c5f553..117ca1e2e 100644 --- a/test/test_re/test_gauss_markov.py +++ b/test/test_re/test_gauss_markov.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - import jax -import numpy as np import jax.numpy as jnp +import numpy as np +import pytest from jax import random, vmap from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py index 8ae98817f..8d159d534 100644 --- a/test/test_re/test_hmc_1d_distributions.py +++ b/test/test_re/test_hmc_1d_distributions.py @@ -1,16 +1,15 @@ import sys +import jax +import nifty8.re as jft import pytest - -pytest.importorskip("jax") - +import scipy from jax import numpy as jnp from jax.scipy import stats from numpy.testing import assert_allclose -import scipy from scipy.special import comb -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py index d114a1838..c906b9784 100644 --- a/test/test_re/test_hmc_hashes.py +++ b/test/test_re/test_hmc_hashes.py @@ -1,14 +1,12 @@ import sys -import pytest - -pytest.importorskip("jax") - import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from numpy import ndarray -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) NDARRAY_TYPE = [ndarray] diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py index 7a14826bb..f763efe0f 100644 --- a/test/test_re/test_hmc_leapfrog.py +++ b/test/test_re/test_hmc_leapfrog.py @@ -1,14 +1,12 @@ -import pytest - -pytest.importorskip("jax") - import sys -from jax import grad +import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from numpy.testing import assert_allclose -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize @@ -29,7 +27,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False) mass_matrix = jnp.ones(shape=dims) kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.0) - potential_energy_gradient = grad(potential_energy) + potential_energy_gradient = jax.grad(potential_energy) positions = [jnp.array([-1.5, -1.55])] momenta = [jnp.array([-1, 1])] for _ in range(25): @@ -52,7 +50,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False) ): msg = ( f"q: {q}; p: {p}" - f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}" + f"\nE_tot: {e_pot + e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}" ) print(msg, file=sys.stderr) diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py index 46ad659b4..0c26f1479 100644 --- a/test/test_re/test_hmc_pytree.py +++ b/test/test_re/test_hmc_pytree.py @@ -1,14 +1,13 @@ -import pytest - -pytest.importorskip("jax") - from functools import partial +import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from jax.tree_util import tree_leaves from numpy.testing import assert_array_equal -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) def test_hmc_pytree(): diff --git a/test/test_re/test_indexing.py b/test/test_re/test_indexing.py index 6afb4306c..ab5ba65e2 100644 --- a/test/test_re/test_indexing.py +++ b/test/test_re/test_indexing.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import numpy as np +import pytest from jax.tree_util import tree_flatten, tree_map from nifty8.re.multi_grid.grid import FlatGrid, Grid, MGrid, OpenGrid from nifty8.re.multi_grid.grid_impl import HEALPixGrid +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py index ff397e0e6..179a81f06 100644 --- a/test/test_re/test_lanczos.py +++ b/test/test_re/test_lanczos.py @@ -2,20 +2,21 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - -from functools import partial import sys -from jax import random +from functools import partial + +import jax import jax.numpy as jnp import numpy as np +import pytest +from jax import random from numpy.testing import assert_allclose from scipy.spatial import distance_matrix import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py index 381068108..a15bccf91 100644 --- a/test/test_re/test_likelihood.py +++ b/test/test_re/test_likelihood.py @@ -2,13 +2,10 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial import jax +import pytest from jax import numpy as jnp from jax import random from jax.tree_util import tree_map @@ -17,6 +14,8 @@ from numpy.testing import assert_allclose import nifty8.re as jft from nifty8.re.likelihood import partial_insert_and_remove as jpartial +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize jax.config.update("jax_enable_x64", True) diff --git a/test/test_re/test_likelihood_impl.py b/test/test_re/test_likelihood_impl.py index f423a1fac..c251e4679 100644 --- a/test/test_re/test_likelihood_impl.py +++ b/test/test_re/test_likelihood_impl.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 -import pytest - -pytest.importorskip("jax") - from functools import partial, reduce import jax import jax.numpy as jnp +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_minisanity.py b/test/test_re/test_minisanity.py index 0e8d3072b..4a5496d1a 100644 --- a/test/test_re/test_minisanity.py +++ b/test/test_re/test_minisanity.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import numpy as np +import pytest from numpy.testing import assert_allclose, assert_array_equal from nifty8.re.minisanity import reduced_residual_stats +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_misc.py b/test/test_re/test_misc.py index 96b6053c9..fec927853 100644 --- a/test/test_re/test_misc.py +++ b/test/test_re/test_misc.py @@ -1,9 +1,10 @@ +import jax +import numpy as np import pytest -pytest.importorskip("jax") - import nifty8.re as jft -import numpy as np + +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py index 2adab31d6..fc17651a1 100644 --- a/test/test_re/test_ncg.py +++ b/test/test_re/test_ncg.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 import sys - -import pytest - -pytest.importorskip("jax") - from functools import partial +import jax import jax.numpy as jnp +import pytest from jax import jit, random, value_and_grad from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_num.py b/test/test_re/test_num.py index 2243fa255..0216a9030 100644 --- a/test/test_re/test_num.py +++ b/test/test_re/test_num.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause +import jax +import numpy as np import pytest - -pytest.importorskip("jax") - from jax import numpy as jnp -import numpy as np from numpy.testing import assert_allclose, assert_array_equal from nifty8.re.num import amend_unique, amend_unique_, unique +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py index dc515c928..1e5f8a288 100644 --- a/test/test_re/test_optimize_kl.py +++ b/test/test_re/test_optimize_kl.py @@ -1,15 +1,12 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial, reduce import jax import jax.numpy as jnp import numpy as np +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose, assert_array_equal @@ -17,6 +14,8 @@ from numpy.testing import assert_allclose, assert_array_equal import nifty8.re as jft from nifty8.re.optimize_kl import concatenate_zip +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize jax.config.update("jax_enable_x64", True) diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py index 3b0a28c27..8af6d237a 100644 --- a/test/test_re/test_refine.py +++ b/test/test_re/test_refine.py @@ -2,18 +2,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -from functools import partial import sys - -import pytest - -pytest.importorskip("jax") +from functools import partial import jax -from jax import random import jax.numpy as jnp -from jax.tree_util import Partial import numpy as np +import pytest +from jax import random +from jax.tree_util import Partial from numpy.testing import assert_allclose from scipy.spatial import distance_matrix @@ -26,6 +23,8 @@ try: except ImportError: healpy = None +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_refine_healpix.py b/test/test_re/test_refine_healpix.py index 42e36fa65..d3adb2d37 100644 --- a/test/test_re/test_refine_healpix.py +++ b/test/test_re/test_refine_healpix.py @@ -5,24 +5,20 @@ import sys from functools import partial -import pytest - -pytest.importorskip("jax") -pytest.importorskip("healpy") - import jax -from jax import random import jax.numpy as jnp -from jax.tree_util import Partial import numpy as np +import pytest +from jax import random +from jax.tree_util import Partial from numpy.testing import assert_allclose, assert_array_equal import nifty8.re as jft -pmp = pytest.mark.parametrize - jax.config.update("jax_enable_x64", True) +pmp = pytest.mark.parametrize + def lst2fixt(lst): @pytest.fixture(params=lst) diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py index e97b1d1a4..eba8420f2 100644 --- a/test/test_re/test_refine_util.py +++ b/test/test_re/test_refine_util.py @@ -4,16 +4,15 @@ from functools import partial -import pytest - -pytest.importorskip("jax") - import jax import numpy as np +import pytest import nifty8.re as jft from nifty8.re.refine import util +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py index 00f71349f..abbc78cba 100644 --- a/test/test_re/test_stats_distributions.py +++ b/test/test_re/test_stats_distributions.py @@ -1,15 +1,13 @@ -import pytest - -pytest.importorskip("jax") - -from functools import partial - +import jax import numpy as np +import pytest from numpy.testing import assert_allclose from scipy import stats import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_vmodel.py b/test/test_re/test_vmodel.py index 465522d7e..44f01fb41 100644 --- a/test/test_re/test_vmodel.py +++ b/test/test_re/test_vmodel.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import jax.numpy as jnp import jax.random as random -from jax import vmap +import pytest from jax.tree_util import tree_map from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize @@ -21,7 +20,7 @@ def _v_compare(model, axis_size, in_axes=0, out_axes=0): x = jft.random_like(random.PRNGKey(10), vmodel.domain) res = vmodel(x) - gt = vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x) + gt = jax.vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x) tree_map(assert_allclose, gt, res) @@ -73,7 +72,7 @@ def test_key(key): x = jft.random_like(random.PRNGKey(10), vmodel.domain) res = vmodel(x) - gt = vmap(model, in_axes=({"a": 0, "b": None},))(x) + gt = jax.vmap(model, in_axes=({"a": 0, "b": None},))(x) assert_allclose(gt, res) # Check intended dict handling with dummy input -- GitLab