diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a98cf9a36d0a0d0499aac3008c6675a3416565e8..52cdb1e1ae9fbd7690dd499c5e28f971b59e6a70 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -43,7 +43,7 @@ test_serial:
   stage: test
   script:
     - pytest -n auto -q --cov=nifty8 --ignore test/test_re/test_optimize_kl.py test
-    - pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py
+    - env XLA_FLAGS="--xla_force_host_platform_device_count=4" pytest -n auto -q --cov=nifty8 --cov-append test/test_re/test_optimize_kl.py
     - >
       python3 -m coverage report --omit "*plot*" | tee coverage.txt
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py
index 06f57273641852082d35fcd77616c201a363a20d..dc515c92871c06867f2440ef7497310f05a22831 100644
--- a/test/test_re/test_optimize_kl.py
+++ b/test/test_re/test_optimize_kl.py
@@ -1,12 +1,6 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
-import os
-
-os.environ["XLA_FLAGS"] = (
-    "--xla_force_host_platform_device_count=4"  # Use 4 CPU devices
-)
-
 import pytest
 
 pytest.importorskip("jax")
@@ -20,13 +14,13 @@ from jax import random
 from jax.tree_util import tree_map
 from numpy.testing import assert_allclose, assert_array_equal
 
-jax.config.update("jax_enable_x64", True)
-
 import nifty8.re as jft
 from nifty8.re.optimize_kl import concatenate_zip
 
 pmp = pytest.mark.parametrize
 
+jax.config.update("jax_enable_x64", True)
+
 
 def random_draw(key, shape, dtype, method):
     def _isleaf(x):
@@ -326,7 +320,7 @@ def test_optimize_kl_device_consistency(
 ):
     devices = jax.devices()
     if not len(devices) > 1:
-        raise RuntimeError("Need more than one device for test.")
+        pytest.skip("Need more than one device for test.")
     if residual_device_map == "pmap" and n_samples > len(devices):
         pytest.skip("n_samples>len(devices), skipping for pmap.")
     if (