diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py
index 9d904a43ea77fb32aab920cae79979222dfb0873..24701163e9640bcb66f39a7952fd378a802bf283 100644
--- a/src/re/optimize_kl.py
+++ b/src/re/optimize_kl.py
@@ -90,8 +90,7 @@ def _kl_vg(
     *,
     map=jax.vmap,
     reduce=_reduce,
-    mesh=None,
-    pspec=None,
+    named_sharding=None,
     kl_device_map="shard_map",
 ):
     assert isinstance(primals_samples, Samples)
@@ -101,26 +100,27 @@ def _kl_vg(
     if len(primals_samples) == 0:
         return jax.value_and_grad(ham)(primals)
 
-    if mesh is None:
+    if named_sharding is None:
         vvg = map(jax.value_and_grad(ham))
     else:
         if kl_device_map == "shard_map":
             vvg = map(jax.value_and_grad(ham))
-            spec_tree = tree_map(lambda x: pspec, primals)
-            out_spec = (pspec, spec_tree)
+            spec_tree = tree_map(lambda x: named_sharding.spec, primals)
+            out_spec = (named_sharding.spec, spec_tree)
             in_spec = (spec_tree,)
-            vvg = shard_map(vvg, mesh=mesh, in_specs=in_spec, out_specs=out_spec)
+            vvg = shard_map(
+                vvg, mesh=named_sharding.mesh, in_specs=in_spec, out_specs=out_spec
+            )
         elif kl_device_map == "jit":
             vvg = map(jax.value_and_grad(ham))
-            sharding = NamedSharding(mesh, pspec)
-            sharding_tree = tree_map(lambda x: sharding, primals)
-            out_sharding = (sharding, sharding_tree)
+            sharding_tree = tree_map(lambda x: named_sharding, primals)
+            out_sharding = (named_sharding, sharding_tree)
             in_sharding = (sharding_tree,)
             vvg = jax.jit(vvg, in_shardings=in_sharding, out_shardings=out_sharding)
         elif kl_device_map == "pmap":
             vvg = jax.pmap(jax.value_and_grad(ham))
         else:
-            ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}"
+            ve = f"`kl_device_map` must be `pmap`, `shard_map`, or `jit`, not {kl_device_map}"
             raise ValueError(ve)
 
     s = vvg(primals_samples.at(primals).samples)
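
As an aside on the refactor above: a `NamedSharding` simply bundles the device `Mesh` with the `PartitionSpec`, and both stay reachable as attributes, so passing the single object loses no information. A minimal sketch (illustrative, not part of the patch):

    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices(), ("x",))
    named_sharding = NamedSharding(mesh, PartitionSpec("x"))
    # the two former arguments are recovered as attributes by the callees
    print(named_sharding.mesh.size, named_sharding.spec)
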
@@ -135,8 +135,7 @@ def _kl_met(
     *,
     map=jax.vmap,
     reduce=_reduce,
-    mesh=None,
-    pspec=None,
+    named_sharding=None,
     kl_device_map="shard_map",
 ):
     assert isinstance(primals_samples, Samples)
@@ -147,31 +146,30 @@ def _kl_met(
         return ham.metric(primals, tangents)
     met = Partial(ham.metric, tangents=tangents)
 
-    if mesh is None:
+    if named_sharding is None:
         vmet = map(met)
     else:
         if kl_device_map == "shard_map":
             vmet = map(met)
-            spec_tree = tree_map(lambda x: pspec, primals)
+            spec_tree = tree_map(lambda x: named_sharding.spec, primals)
             out_spec = spec_tree
             in_spec = (spec_tree,)
             vmet = shard_map(
                 vmet,
-                mesh=mesh,
+                mesh=named_sharding.mesh,
                 in_specs=in_spec,
                 out_specs=out_spec,
             )
         elif kl_device_map == "jit":
             vmet = map(met)
-            sharding = NamedSharding(mesh, pspec)
-            sharding_tree = tree_map(lambda x: sharding, primals)
+            sharding_tree = tree_map(lambda x: named_sharding, primals)
             out_sharding = sharding_tree
             in_sharding = (sharding_tree,)
             vmet = jax.jit(vmet, in_shardings=in_sharding, out_shardings=out_sharding)
         elif kl_device_map == "pmap":
             vmet = jax.pmap(met)
         else:
-            ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}"
+            ve = f"`kl_device_map` must be `pmap`, `shard_map`, or `jit`, not {kl_device_map}"
             raise ValueError(ve)
     s = vmet(primals_samples.at(primals).samples)
     return reduce(s)
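
The `shard_map` branches in `_kl_vg`/`_kl_met` follow the usual pattern of deriving per-argument specs from the stored sharding. A self-contained sketch with a toy stand-in for `ham`, assuming a 1-D "x" mesh (names here are illustrative only):

    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    named_sharding = NamedSharding(Mesh(jax.devices(), ("x",)), PartitionSpec("x"))

    vvg = jax.vmap(jax.value_and_grad(lambda x: jnp.sum(x**2)))  # toy `ham`
    vvg = shard_map(
        vvg,
        mesh=named_sharding.mesh,
        in_specs=(named_sharding.spec,),
        out_specs=(named_sharding.spec, named_sharding.spec),
    )
    samples = jnp.ones((jax.device_count(), 4))  # one sample per device
    values, grads = vvg(samples)  # each device evaluates only its own sample
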
@@ -338,11 +336,11 @@ class OptimizeVI:
         kl_jit = _parse_jit(kl_jit)
         residual_jit = _parse_jit(residual_jit)
         residual_map = get_map(residual_map)
-        self.mesh = None
-        self.pspec = None
+        self.named_sharding = None
         if (not devices is None) and len(devices) > 1:
-            self.mesh = Mesh(devices, ("x",))
-            self.pspec = Pspec("x")
+            mesh = Mesh(devices, ("x",))
+            pspec = Pspec("x")
+            self.named_sharding = NamedSharding(mesh, pspec)
         self.residual_device_map = residual_device_map
 
         if mirror_samples is False:
@@ -352,26 +350,34 @@ class OptimizeVI:
             _kl_value_and_grad = partial(
                 kl_jit(
                     _kl_vg,
-                    static_argnames=("map", "reduce", "mesh", "pspec", "kl_device_map"),
+                    static_argnames=(
+                        "map",
+                        "reduce",
+                        "named_sharding",
+                        "kl_device_map",
+                    ),
                 ),
                 likelihood,
                 map=kl_map,
                 reduce=kl_reduce,
-                mesh=self.mesh,
-                pspec=self.pspec,
+                named_sharding=self.named_sharding,
                 kl_device_map=kl_device_map,
             )
         if _kl_metric is None:
             _kl_metric = partial(
                 kl_jit(
                     _kl_met,
-                    static_argnames=("map", "reduce", "mesh", "pspec", "kl_device_map"),
+                    static_argnames=(
+                        "map",
+                        "reduce",
+                        "named_sharding",
+                        "kl_device_map",
+                    ),
                 ),
                 likelihood,
                 map=kl_map,
                 reduce=kl_reduce,
-                mesh=self.mesh,
-                pspec=self.pspec,
+                named_sharding=self.named_sharding,
                 kl_device_map=kl_device_map,
             )
         if _draw_linear_residual is None:
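
A note on the `static_argnames` tuples above: `NamedSharding` (like the `Mesh`/`PartitionSpec` pair it replaces) is hashable, which is what allows it to be passed as a static argument through `kl_jit`. A standalone sketch with an illustrative function name:

    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    ns = NamedSharding(Mesh(jax.devices(), ("x",)), PartitionSpec("x"))
    hash(ns)  # hashable, hence usable as a compile-time constant

    @partial(jax.jit, static_argnames=("named_sharding",))
    def scale_by_mesh_size(x, *, named_sharding):
        # attribute access happens at trace time since the argument is static
        return x / named_sharding.mesh.size

    scale_by_mesh_size(jnp.ones(3), named_sharding=ns)
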
@@ -401,7 +407,7 @@ class OptimizeVI:
         # NOTE, use `Partial` in favor of `partial` to allow the (potentially)
         # re-jitting `residual_map` to trace the kwargs
         kwargs = hide_strings(kwargs)
-        if self.mesh is None:
+        if self.named_sharding is None:
             sampler = Partial(self.draw_linear_residual, **kwargs)
             sampler = self.residual_map(sampler, in_axes=(None, 0))
             smpls, smpls_states = sampler(primals, keys)
@@ -410,44 +416,48 @@ class OptimizeVI:
             smpls = concatenate_zip(smpls, -smpls)
         else:
             n_samples = len(keys)
-            if n_samples == self.mesh.size / 2:
+            if n_samples == self.named_sharding.mesh.size / 2:
                 keys = jnp.repeat(keys, 2, axis=0)
-            keys = jax.device_put(keys, NamedSharding(self.mesh, self.pspec))
+            keys = jax.device_put(keys, self.named_sharding)
 
             # zip samples such that the mirrored-counterpart always comes right
             # after the original sample. out_shardings is needed to tell JAX
             # not to move all samples to a single device.
-            @partial(jax.jit, out_shardings=NamedSharding(self.mesh, self.pspec))
+            @partial(jax.jit, out_shardings=self.named_sharding)
             def concatenate_zip_pmap(*arrays):
                 return tree_map(
                     lambda *x: jnp.stack(x, axis=1).reshape((-1,) + x[0].shape[1:]),
                     *arrays,
                 )
 
-            @partial(jax.jit, out_shardings=NamedSharding(self.mesh, self.pspec))
+            @partial(jax.jit, out_shardings=self.named_sharding)
             def _special_mirror_samples(samples):
                 return samples.at[1::2].set(-samples[1::2])
 
             if self.residual_device_map == "pmap":
                 sampler = Partial(self.draw_linear_residual, **kwargs)
                 sampler = jax.pmap(sampler, in_axes=(None, 0))
-                keys = jax.device_put(keys, NamedSharding(self.mesh, self.pspec))
+                keys = jax.device_put(keys, self.named_sharding)
                 smpls, smpls_states = sampler(primals, keys)
             elif self.residual_device_map == "jit":
                 sampler = Partial(self.draw_linear_residual, primals, **kwargs)
                 sampler = self.residual_map(sampler)
-                sharding = NamedSharding(self.mesh, self.pspec)
                 sampler = jax.jit(
-                    sampler, in_shardings=sharding, out_shardings=sharding
+                    sampler,
+                    in_shardings=self.named_sharding,
+                    out_shardings=self.named_sharding,
                 )
                 smpls, smpls_states = sampler(keys)
             elif self.residual_device_map == "shard_map":
                 sampler = Partial(self.draw_linear_residual, primals, **kwargs)
-                out_spec = (tree_map(lambda x: self.pspec, primals), self.pspec)
+                out_spec = (
+                    tree_map(lambda x: self.named_sharding.spec, primals),
+                    self.named_sharding.spec,
+                )
                 sampler = shard_map(
                     self.residual_map(sampler),
-                    mesh=self.mesh,
-                    in_specs=self.pspec,
+                    mesh=self.named_sharding.mesh,
+                    in_specs=self.named_sharding.spec,
                     out_specs=out_spec,
                     check_rep=False,  # FIXME Maybe enable in future JAX releases
                 )
@@ -455,7 +465,7 @@ class OptimizeVI:
             else:
                 ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}"
                 raise ValueError(ve)
-            if n_samples == self.mesh.size / 2:
+            if n_samples == self.named_sharding.mesh.size / 2:
                 smpls = tree_map(_special_mirror_samples, smpls)
                 keys = keys[::2]  # undo jnp.repeat
             else:
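
The interleaving above guarantees that each mirrored sample lands directly after its original. A small standalone sketch with toy 1-D arrays (the real code applies this over the `Samples` pytree):

    import jax.numpy as jnp

    def concatenate_zip(*arrays):
        # same stack/reshape trick as `concatenate_zip_pmap` above
        return jnp.stack(arrays, axis=1).reshape((-1,) + arrays[0].shape[1:])

    smpls = jnp.array([1.0, 2.0, 3.0])
    concatenate_zip(smpls, -smpls)              # [ 1. -1.  2. -2.  3. -3.]

    # When the keys were jnp.repeat-ed to fill the mesh, the same layout is
    # produced in place by negating every second entry:
    doubled = jnp.repeat(smpls, 2, axis=0)      # [ 1.  1.  2.  2.  3.  3.]
    doubled.at[1::2].set(-doubled[1::2])        # [ 1. -1.  2. -2.  3. -3.]
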
@@ -472,7 +482,7 @@ class OptimizeVI:
         metric_sample_key = concatenate_zip(*((samples.keys,) * 2))
         sgn = jnp.ones(len(samples.keys))
         sgn = concatenate_zip(sgn, -sgn)
-        if self.mesh is None:
+        if self.named_sharding is None:
             curver = Partial(self.nonlinearly_update_residual, **kwargs)
             curver = self.residual_map(curver, in_axes=(None, 0, 0, 0))
             smpls, smpls_states = curver(
@@ -481,26 +491,31 @@ class OptimizeVI:
         else:
             curver = Partial(self.nonlinearly_update_residual, samples.pos, **kwargs)
             if self.residual_device_map == "shard_map":
-                spec_tree = tree_map(lambda x: self.pspec, samples.pos)
-                out_spec = (spec_tree, self.pspec)
-                in_spec = (spec_tree, self.pspec, self.pspec)
+                spec_tree = tree_map(lambda x: self.named_sharding.spec, samples.pos)
+                out_spec = (spec_tree, self.named_sharding.spec)
+                in_spec = (
+                    spec_tree,
+                    self.named_sharding.spec,
+                    self.named_sharding.spec,
+                )
                 curver = shard_map(
                     self.residual_map(curver, in_axes=(0, 0, 0)),
-                    mesh=self.mesh,
+                    mesh=self.named_sharding.mesh,
                     in_specs=in_spec,
                     out_specs=out_spec,
                     check_rep=False,
                 )
             elif self.residual_device_map == "jit":
-                sharding = NamedSharding(self.mesh, self.pspec)
-                sharding_tree = tree_map(lambda x: sharding, samples.pos)
-                out_sharding = (sharding_tree, sharding)
-                in_sharding = (sharding_tree, sharding, sharding)
+                sharding_tree = tree_map(lambda x: self.named_sharding, samples.pos)
+                out_sharding = (sharding_tree, self.named_sharding)
+                in_sharding = (sharding_tree, self.named_sharding, self.named_sharding)
                 curver = self.residual_map(curver)
                 curver = jax.jit(
                     curver, in_shardings=in_sharding, out_shardings=out_sharding
                 )
-                metric_sample_key = jax.device_put(metric_sample_key, sharding)
+                metric_sample_key = jax.device_put(
+                    metric_sample_key, self.named_sharding
+                )
             elif self.residual_device_map == "pmap":
                 curver = jax.pmap(curver, in_axes=(0, 0, 0))
             else:
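
For the `jit` branches (both in the sampling and in the nonlinear-update path), data placement is now driven entirely by the stored `NamedSharding`. A minimal standalone sketch with a toy sampler (illustrative names, not the library API):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    named_sharding = NamedSharding(Mesh(jax.devices(), ("x",)), PartitionSpec("x"))

    keys = jax.random.split(jax.random.PRNGKey(42), jax.device_count())
    keys = jax.device_put(keys, named_sharding)  # one key per device

    sampler = jax.vmap(lambda k: jax.random.normal(k, (4,)))  # toy residual draw
    sampler = jax.jit(
        sampler, in_shardings=named_sharding, out_shardings=named_sharding
    )
    smpls = sampler(keys)  # result stays distributed across the mesh
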
diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py
index 7a5d1ed2700be5cf8b9ab0a8ae5fbf5a6f099669..06f57273641852082d35fcd77616c201a363a20d 100644
--- a/test/test_re/test_optimize_kl.py
+++ b/test/test_re/test_optimize_kl.py
@@ -379,10 +379,10 @@ def test_optimize_kl_device_consistency(
         kl_jit=False,
         residual_jit=False,
     )
-    samples_single_device, _ = jft.optimize_kl(**opt_kl_kwargs, map_over_devices=None)
+    samples_single_device, _ = jft.optimize_kl(**opt_kl_kwargs, devices=None)
     samples_multiple_devices, _ = jft.optimize_kl(
         **opt_kl_kwargs,
-        map_over_devices=jax.devices(),
+        devices=jax.devices(),
         residual_device_map=residual_device_map,
         kl_device_map=kl_device_map,
     )