diff --git a/demos/0_intro.py b/demos/0_intro.py index 4c61e1ebab94db5e629cb8fbfa7212517815aa04..0ac93f5300d92ff52ddc6982abe2dd73924e5c41 100644 --- a/demos/0_intro.py +++ b/demos/0_intro.py @@ -11,11 +11,10 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) seed = 42 diff --git a/demos/1_tomography.py b/demos/1_tomography.py index dc27d917407b7047f94b16ecf72c8f2e44af5db3..90244368329eaa48d2eef5aabf18f6c6153273a0 100644 --- a/demos/1_tomography.py +++ b/demos/1_tomography.py @@ -9,12 +9,11 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft import numpy as np from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) # %% diff --git a/demos/a_icr.py b/demos/a_icr.py index 4791a6f89ede5638a462494acb3202ee76e201ec..526b690f4395ae01154122ae72f3f9d5b478e47e 100644 --- a/demos/a_icr.py +++ b/demos/a_icr.py @@ -2,7 +2,6 @@ import jax import matplotlib.pyplot as plt - import nifty8.re as jft from jax import random diff --git a/demos/a_nonlinear_regression.py b/demos/a_nonlinear_regression.py index 446880cd8e2cc4e089cdf7e4686c9632a4ff1426..d019e12e5c5f2f2fda34bb53764fa37ff7758066 100644 --- a/demos/a_nonlinear_regression.py +++ b/demos/a_nonlinear_regression.py @@ -7,15 +7,15 @@ # # Demonstration of a non-linear regression using NIFTy.re # %% +import operator as op +from functools import partial + import jax import matplotlib.pyplot as plt -from jax import numpy as jnp +import nifty8.re as jft import numpy as np +from jax import numpy as jnp from jax import random -from functools import partial -import operator as op - -import nifty8.re as jft jax.config.update("jax_enable_x64", True) diff --git a/demos/a_wiener_filter.py b/demos/a_wiener_filter.py index e44d494238e8cf8ddc55b03c8c5f9ef2a519e7c0..9513a95398b23027f0887a756154e2ce64b18f88 100644 --- a/demos/a_wiener_filter.py +++ b/demos/a_wiener_filter.py @@ -11,11 +11,10 @@ # %% import jax import matplotlib.pyplot as plt +import nifty8.re as jft from jax import numpy as jnp from jax import random -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) seed = 42 diff --git a/test/test_re/__init__.py b/test/test_re/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/test/test_re/test_correlated_field.py b/test/test_re/test_correlated_field.py index 91d853eff3421304e068665cc3fe9fc1e6ab6b5b..0c1004bb7b36a3acca88c408ce45a8049dcd1d06 100644 --- a/test/test_re/test_correlated_field.py +++ b/test/test_re/test_correlated_field.py @@ -1,16 +1,14 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import jax.random as random +import nifty8 as ift +import nifty8.re as jft import pytest from numpy.testing import assert_allclose -import nifty8 as ift -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_custom_map.py b/test/test_re/test_custom_map.py index 0bbbcbac596605c2848d865da5c9edd04c0bd772..d9cecc2399be2f5ab076a0e22d7960d275d60414 100644 --- a/test/test_re/test_custom_map.py +++ b/test/test_re/test_custom_map.py @@ -1,23 +1,19 @@ #!/usr/bin/env python3 -import pytest - -pytest.importorskip("jax") - from functools import partial import jax import jax.numpy as jnp +import nifty8.re as jft +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize -jax.config.update("jax_enable_x64", True) - def f(u, v): return jnp.exp(u @ v) diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py index 8d80c7c8f2a6e9c18facc4b20328c3bb9c248df7..072f1992a53a19b3f8b7636d3715ca9fe0048b9b 100644 --- a/test/test_re/test_estimate_evidence_lower_bound.py +++ b/test/test_re/test_estimate_evidence_lower_bound.py @@ -3,17 +3,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause # Author: Matteo Guardiani -import pytest - -pytest.importorskip("jax") - import jax import jax.random as random import numpy as np +import pytest import nifty8 as ift import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_evi.py b/test/test_re/test_evi.py index 88dab250bdbfc9ab7dc485dc97804f6168cd271b..b4f990342030792e213cb8c72a08b5f3943db539 100644 --- a/test/test_re/test_evi.py +++ b/test/test_re/test_evi.py @@ -1,21 +1,18 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial import jax import jax.numpy as jnp +import nifty8.re as jft import numpy as np +import pytest from jax import random from numpy.testing import assert_allclose, assert_array_equal -import nifty8.re as jft - jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_forest_math.py b/test/test_re/test_forest_math.py index 6ec08007f4bae0ca7f79381bf71a3b9784939344..308d0acbdeed9f3f431c214b86917c57b0df2886 100644 --- a/test/test_re/test_forest_math.py +++ b/test/test_re/test_forest_math.py @@ -1,8 +1,8 @@ +import jax +import nifty8.re as jft import pytest -pytest.importorskip("jax") - -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) def test_map_forest_axes_validation(): diff --git a/test/test_re/test_gauss_markov.py b/test/test_re/test_gauss_markov.py index 497c5f5533c5b601748ca9f87874740b01105aac..117ca1e2e943a29b85089fb299cda7c8e3ad02ab 100644 --- a/test/test_re/test_gauss_markov.py +++ b/test/test_re/test_gauss_markov.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - import jax -import numpy as np import jax.numpy as jnp +import numpy as np +import pytest from jax import random, vmap from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py index 8ae98817f71878ef6ace39faca82e721eeabf046..8d159d5342e8c5c79cd0b118a086ce070a90fd50 100644 --- a/test/test_re/test_hmc_1d_distributions.py +++ b/test/test_re/test_hmc_1d_distributions.py @@ -1,16 +1,15 @@ import sys +import jax +import nifty8.re as jft import pytest - -pytest.importorskip("jax") - +import scipy from jax import numpy as jnp from jax.scipy import stats from numpy.testing import assert_allclose -import scipy from scipy.special import comb -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py index d114a1838768874c26d93cc9be89a14ffdda5408..c906b9784815a0cbd76cdf59c5f25f33c8c07d01 100644 --- a/test/test_re/test_hmc_hashes.py +++ b/test/test_re/test_hmc_hashes.py @@ -1,14 +1,12 @@ import sys -import pytest - -pytest.importorskip("jax") - import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from numpy import ndarray -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) NDARRAY_TYPE = [ndarray] diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py index 7a14826bbc795eef7c9b74561d333eea7353f54f..f763efe0f977f37f79efc0ece1b336af005a0867 100644 --- a/test/test_re/test_hmc_leapfrog.py +++ b/test/test_re/test_hmc_leapfrog.py @@ -1,14 +1,12 @@ -import pytest - -pytest.importorskip("jax") - import sys -from jax import grad +import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from numpy.testing import assert_allclose -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize @@ -29,7 +27,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False) mass_matrix = jnp.ones(shape=dims) kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.0) - potential_energy_gradient = grad(potential_energy) + potential_energy_gradient = jax.grad(potential_energy) positions = [jnp.array([-1.5, -1.55])] momenta = [jnp.array([-1, 1])] for _ in range(25): @@ -52,7 +50,7 @@ def test_leapfrog_energy_conservation(potential_energy, rtol, interactive=False) ): msg = ( f"q: {q}; p: {p}" - f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}" + f"\nE_tot: {e_pot + e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}" ) print(msg, file=sys.stderr) diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py index 46ad659b48d49c51fe315125dc44cc6bbfe59daa..0c26f1479e31faca7f781d15a09ffb8bec1e4540 100644 --- a/test/test_re/test_hmc_pytree.py +++ b/test/test_re/test_hmc_pytree.py @@ -1,14 +1,13 @@ -import pytest - -pytest.importorskip("jax") - from functools import partial +import jax +import nifty8.re as jft +import pytest from jax import numpy as jnp from jax.tree_util import tree_leaves from numpy.testing import assert_array_equal -import nifty8.re as jft +jax.config.update("jax_enable_x64", True) def test_hmc_pytree(): diff --git a/test/test_re/test_indexing.py b/test/test_re/test_indexing.py index 6afb4306cab204a29230d7067e01d128e4910ee3..ab5ba65e2f009631333fb691eb21ea88793708d5 100644 --- a/test/test_re/test_indexing.py +++ b/test/test_re/test_indexing.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import numpy as np +import pytest from jax.tree_util import tree_flatten, tree_map from nifty8.re.multi_grid.grid import FlatGrid, Grid, MGrid, OpenGrid from nifty8.re.multi_grid.grid_impl import HEALPixGrid +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py index ff397e0e67ca4d93fefdd9fda74a6115ca135cce..179a81f06d13db5dcb1c67ff5bcea834e9fb0776 100644 --- a/test/test_re/test_lanczos.py +++ b/test/test_re/test_lanczos.py @@ -2,20 +2,21 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - -from functools import partial import sys -from jax import random +from functools import partial + +import jax import jax.numpy as jnp import numpy as np +import pytest +from jax import random from numpy.testing import assert_allclose from scipy.spatial import distance_matrix import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py index 3810681082f2395b7bbb07997b15e50834adb665..a15bccf91e728f35cf998b1dd31c0ab137af0779 100644 --- a/test/test_re/test_likelihood.py +++ b/test/test_re/test_likelihood.py @@ -2,13 +2,10 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial import jax +import pytest from jax import numpy as jnp from jax import random from jax.tree_util import tree_map @@ -17,6 +14,8 @@ from numpy.testing import assert_allclose import nifty8.re as jft from nifty8.re.likelihood import partial_insert_and_remove as jpartial +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize jax.config.update("jax_enable_x64", True) diff --git a/test/test_re/test_likelihood_impl.py b/test/test_re/test_likelihood_impl.py index f423a1facfe9aa3a33e4679e18570ea7df3bb954..c251e4679265143bddef97fc0164cd35cc5361c6 100644 --- a/test/test_re/test_likelihood_impl.py +++ b/test/test_re/test_likelihood_impl.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 -import pytest - -pytest.importorskip("jax") - from functools import partial, reduce import jax import jax.numpy as jnp +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_minisanity.py b/test/test_re/test_minisanity.py index 0e8d3072bd7949db2a61250773d497bc4c09231c..4a5496d1aae34c5aa39656a2c2db81fd538829c8 100644 --- a/test/test_re/test_minisanity.py +++ b/test/test_re/test_minisanity.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import numpy as np +import pytest from numpy.testing import assert_allclose, assert_array_equal from nifty8.re.minisanity import reduced_residual_stats +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_misc.py b/test/test_re/test_misc.py index 96b6053c9f3f2142a356d80c16a253abb33c6833..fec927853907aad975371f8f11cd12860c21daa1 100644 --- a/test/test_re/test_misc.py +++ b/test/test_re/test_misc.py @@ -1,9 +1,10 @@ +import jax +import numpy as np import pytest -pytest.importorskip("jax") - import nifty8.re as jft -import numpy as np + +jax.config.update("jax_enable_x64", True) pmp = pytest.mark.parametrize diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py index 2adab31d6cb053a0d9e7887433623315a689bfa9..fc17651a1febb92f7f421c1e540d88955059c645 100644 --- a/test/test_re/test_ncg.py +++ b/test/test_re/test_ncg.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 import sys - -import pytest - -pytest.importorskip("jax") - from functools import partial +import jax import jax.numpy as jnp +import pytest from jax import jit, random, value_and_grad from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_num.py b/test/test_re/test_num.py index 2243fa255809395e51015cf5fa4cb5b0ce010017..0216a9030a0647f73e878433325113f4a95bc627 100644 --- a/test/test_re/test_num.py +++ b/test/test_re/test_num.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause +import jax +import numpy as np import pytest - -pytest.importorskip("jax") - from jax import numpy as jnp -import numpy as np from numpy.testing import assert_allclose, assert_array_equal from nifty8.re.num import amend_unique, amend_unique_, unique +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py index dc515c92871c06867f2440ef7497310f05a22831..1e5f8a2881b698d40acfd793c9343daa36253d53 100644 --- a/test/test_re/test_optimize_kl.py +++ b/test/test_re/test_optimize_kl.py @@ -1,15 +1,12 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - from functools import partial, reduce import jax import jax.numpy as jnp import numpy as np +import pytest from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose, assert_array_equal @@ -17,6 +14,8 @@ from numpy.testing import assert_allclose, assert_array_equal import nifty8.re as jft from nifty8.re.optimize_kl import concatenate_zip +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize jax.config.update("jax_enable_x64", True) diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py index 3b0a28c27599dfd992feb818931c880a8ae6528e..8af6d237a7bc5cb878b1581360463d58020a8c8b 100644 --- a/test/test_re/test_refine.py +++ b/test/test_re/test_refine.py @@ -2,18 +2,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -from functools import partial import sys - -import pytest - -pytest.importorskip("jax") +from functools import partial import jax -from jax import random import jax.numpy as jnp -from jax.tree_util import Partial import numpy as np +import pytest +from jax import random +from jax.tree_util import Partial from numpy.testing import assert_allclose from scipy.spatial import distance_matrix @@ -26,6 +23,8 @@ try: except ImportError: healpy = None +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_refine_healpix.py b/test/test_re/test_refine_healpix.py index 42e36fa65f145a3eeeaf2ad140581e924089ab99..d3adb2d3792ebed36614de6ba93e173882dfbad0 100644 --- a/test/test_re/test_refine_healpix.py +++ b/test/test_re/test_refine_healpix.py @@ -5,24 +5,20 @@ import sys from functools import partial -import pytest - -pytest.importorskip("jax") -pytest.importorskip("healpy") - import jax -from jax import random import jax.numpy as jnp -from jax.tree_util import Partial import numpy as np +import pytest +from jax import random +from jax.tree_util import Partial from numpy.testing import assert_allclose, assert_array_equal import nifty8.re as jft -pmp = pytest.mark.parametrize - jax.config.update("jax_enable_x64", True) +pmp = pytest.mark.parametrize + def lst2fixt(lst): @pytest.fixture(params=lst) diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py index e97b1d1a4a76e9414b5b12613870f3874c969a7f..eba8420f265452a80f377f58e94100fbddd819fc 100644 --- a/test/test_re/test_refine_util.py +++ b/test/test_re/test_refine_util.py @@ -4,16 +4,15 @@ from functools import partial -import pytest - -pytest.importorskip("jax") - import jax import numpy as np +import pytest import nifty8.re as jft from nifty8.re.refine import util +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py index 00f71349f1c770750a08e900bbfdee96510d6863..abbc78cbaf1db41c2fd12df3dead80673f282c30 100644 --- a/test/test_re/test_stats_distributions.py +++ b/test/test_re/test_stats_distributions.py @@ -1,15 +1,13 @@ -import pytest - -pytest.importorskip("jax") - -from functools import partial - +import jax import numpy as np +import pytest from numpy.testing import assert_allclose from scipy import stats import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize diff --git a/test/test_re/test_vmodel.py b/test/test_re/test_vmodel.py index 465522d7e75aa31fcaab546455df94799ee38f5e..44f01fb41b8b5f9ac82d591fe092f7769f66c9e1 100644 --- a/test/test_re/test_vmodel.py +++ b/test/test_re/test_vmodel.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import pytest - -pytest.importorskip("jax") - +import jax import jax.numpy as jnp import jax.random as random -from jax import vmap +import pytest from jax.tree_util import tree_map from numpy.testing import assert_allclose import nifty8.re as jft +jax.config.update("jax_enable_x64", True) + pmp = pytest.mark.parametrize @@ -21,7 +20,7 @@ def _v_compare(model, axis_size, in_axes=0, out_axes=0): x = jft.random_like(random.PRNGKey(10), vmodel.domain) res = vmodel(x) - gt = vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x) + gt = jax.vmap(model, in_axes=(in_axes,), out_axes=out_axes)(x) tree_map(assert_allclose, gt, res) @@ -73,7 +72,7 @@ def test_key(key): x = jft.random_like(random.PRNGKey(10), vmodel.domain) res = vmodel(x) - gt = vmap(model, in_axes=({"a": 0, "b": None},))(x) + gt = jax.vmap(model, in_axes=({"a": 0, "b": None},))(x) assert_allclose(gt, res) # Check intended dict handling with dummy input