Skip to content
Snippets Groups Projects
Commit 694f7d9e authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

Skip instead of raise optimize_kl device test for one device

parent c87fdeb8
No related branches found
No related tags found
1 merge request!993Multi gpu
...@@ -43,7 +43,7 @@ test_serial: ...@@ -43,7 +43,7 @@ test_serial:
stage: test stage: test
script: script:
- pytest -n auto -q --cov=nifty8 --ignore test/test_re/test_optimize_kl.py test - pytest -n auto -q --cov=nifty8 --ignore test/test_re/test_optimize_kl.py test
- pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py - env XLA_FLAGS="--xla_force_host_platform_device_count=4" pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py
- > - >
python3 -m coverage report --omit "*plot*" | tee coverage.txt python3 -m coverage report --omit "*plot*" | tee coverage.txt
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
import os
os.environ["XLA_FLAGS"] = (
"--xla_force_host_platform_device_count=4" # Use 4 CPU devices
)
import pytest import pytest
pytest.importorskip("jax") pytest.importorskip("jax")
...@@ -20,13 +14,13 @@ from jax import random ...@@ -20,13 +14,13 @@ from jax import random
from jax.tree_util import tree_map from jax.tree_util import tree_map
from numpy.testing import assert_allclose, assert_array_equal from numpy.testing import assert_allclose, assert_array_equal
jax.config.update("jax_enable_x64", True)
import nifty8.re as jft import nifty8.re as jft
from nifty8.re.optimize_kl import concatenate_zip from nifty8.re.optimize_kl import concatenate_zip
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
jax.config.update("jax_enable_x64", True)
def random_draw(key, shape, dtype, method): def random_draw(key, shape, dtype, method):
def _isleaf(x): def _isleaf(x):
...@@ -326,7 +320,7 @@ def test_optimize_kl_device_consistency( ...@@ -326,7 +320,7 @@ def test_optimize_kl_device_consistency(
): ):
devices = jax.devices() devices = jax.devices()
if not len(devices) > 1: if not len(devices) > 1:
raise RuntimeError("Need more than one device for test.") pytest.skip("Need more than one device for test.")
if residual_device_map == "pmap" and n_samples > len(devices): if residual_device_map == "pmap" and n_samples > len(devices):
pytest.skip("n_samples>len(devices), skipping for pmap.") pytest.skip("n_samples>len(devices), skipping for pmap.")
if ( if (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment