diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py
index fd6a471905aa7a560c52f34eed766a6ae27b735e..9d904a43ea77fb32aab920cae79979222dfb0873 100644
--- a/src/re/optimize_kl.py
+++ b/src/re/optimize_kl.py
@@ -269,7 +269,7 @@ class OptimizeVI:
         residual_map="lmap",
         kl_reduce=_reduce,
         mirror_samples=True,
-        map_over_devices: Optional[list] = None,
+        devices: Optional[list] = None,
         kl_device_map="shard_map",
         residual_device_map="pmap",
         _kl_value_and_grad: Optional[Callable] = None,
@@ -300,7 +300,7 @@ class OptimizeVI:
             Reduce function used for the KL minimization.
         mirror_samples: bool
             Whether to mirror the samples or not.
-        map_over_devices : list of devices or None
+        devices : list of devices or None
             Devices over which the samples are mapped. If `None` only the
             default device is used. To use all detected devices pass
             jax.devices(). Generally the samples needs to be evenly
@@ -308,17 +308,17 @@ class OptimizeVI:
             the `kl_device_map` and `residual_device_map` arguments.
         kl_device_map : str
             Map function used for mapping KL minimization over the devices
-            listed in `map_over_devices`. `kl_device_map` can be 'shard_map',
-            'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to
-            the number of devices. For all other maps the samples needs to be
-            equally distributable over the devices.
+            listed in `devices`. `kl_device_map` can be 'shard_map', 'pmap', or
+            'jit'. If set to 'pmap', 2*n_samples need to be equal to the number
+            of devices. For all other maps the samples need to be equally
+            distributable over the devices.
         residual_device_map : str
             Map function used for mapping sampling over the devices listed in
-            `map_over_devices`. `residual_device_map` can be 'shard_map',
-            'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to
-            the number of devices. If only linear samples are drawn ,'pmap' also
-            works if n_samples equals the number of devices. For the other maps
-            it is sufficient if 2*n_samples equals the number of devices, or
+            `devices`. `residual_device_map` can be 'shard_map', 'pmap', or
+            'jit'. If set to 'pmap', 2*n_samples need to be equal to the number
+            of devices. If only linear samples are drawn, 'pmap' also works if
+            n_samples equals the number of devices. For the other maps it is
+            sufficient if 2*n_samples equals the number of devices, or
             n_samples can be evenly divided by the number of devices.
 
         Notes
@@ -340,8 +340,8 @@ class OptimizeVI:
         residual_map = get_map(residual_map)
         self.mesh = None
         self.pspec = None
-        if (not map_over_devices is None) and len(map_over_devices) > 1:
-            self.mesh = Mesh(map_over_devices, ("x",))
+        if (not devices is None) and len(devices) > 1:
+            self.mesh = Mesh(devices, ("x",))
             self.pspec = Pspec("x")
         self.residual_device_map = residual_device_map
 
@@ -795,7 +795,7 @@ def optimize_kl(
     resume: Union[str, bool] = False,
     callback: Optional[Callable[[Samples, OptimizeVIState], None]] = None,
     odir: Optional[str] = None,
-    map_over_devices: Optional[list] = None,
+    devices: Optional[list] = None,
     kl_device_map="shard_map",
     residual_device_map="pmap",
     _optimize_vi=None,
@@ -835,7 +835,7 @@ def optimize_kl(
             residual_map=residual_map,
             kl_reduce=kl_reduce,
             mirror_samples=mirror_samples,
-            map_over_devices=map_over_devices,
+            devices=devices,
             kl_device_map=kl_device_map,
             residual_device_map=residual_device_map,
         )