diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a98cf9a36d0a0d0499aac3008c6675a3416565e8..52cdb1e1ae9fbd7690dd499c5e28f971b59e6a70 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -43,7 +43,7 @@ test_serial: stage: test script: - pytest -n auto -q --cov=nifty8 --ignore test/test_re/test_optimize_kl.py test - - pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py + - env XLA_FLAGS="--xla_force_host_platform_device_count=4" pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py - > python3 -m coverage report --omit "*plot*" | tee coverage.txt coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py index 06f57273641852082d35fcd77616c201a363a20d..dc515c92871c06867f2440ef7497310f05a22831 100644 --- a/test/test_re/test_optimize_kl.py +++ b/test/test_re/test_optimize_kl.py @@ -1,12 +1,6 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause -import os - -os.environ["XLA_FLAGS"] = ( - "--xla_force_host_platform_device_count=4" # Use 4 CPU devices -) - import pytest pytest.importorskip("jax") @@ -20,13 +14,13 @@ from jax import random from jax.tree_util import tree_map from numpy.testing import assert_allclose, assert_array_equal -jax.config.update("jax_enable_x64", True) - import nifty8.re as jft from nifty8.re.optimize_kl import concatenate_zip pmp = pytest.mark.parametrize +jax.config.update("jax_enable_x64", True) + def random_draw(key, shape, dtype, method): def _isleaf(x): @@ -326,7 +320,7 @@ def test_optimize_kl_device_consistency( ): devices = jax.devices() if not len(devices) > 1: - raise RuntimeError("Need more than one device for test.") + pytest.skip("Need more than one device for test.") if residual_device_map == "pmap" and n_samples > len(devices): pytest.skip("n_samples>len(devices), skipping for pmap.") if (