diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py
index 24701163e9640bcb66f39a7952fd378a802bf283..25aebd326b05c63a39687f9bad9a3ffeb44536a8 100644
--- a/src/re/optimize_kl.py
+++ b/src/re/optimize_kl.py
@@ -15,9 +15,10 @@ import jax
 import numpy as np
 from jax import numpy as jnp
 from jax import random
-from jax.tree_util import Partial, tree_map
-from jax.sharding import PartitionSpec as Pspec, NamedSharding, Mesh
 from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as Pspec
+from jax.tree_util import Partial, tree_map
 
 from . import optimize
 from .evi import (
@@ -603,7 +604,7 @@ class OptimizeVI:
         pl = samples.pos
         if constants:
             from .likelihood import _parse_point_estimates, partial_insert_and_remove
-            from .tree_math import zeros_like, Vector
+            from .tree_math import Vector, zeros_like
 
             insert_axes, pl, primals_frozen = _parse_point_estimates(constants, pl)
             unflatten = Vector if insert_axes else None
@@ -777,7 +778,7 @@ class OptimizeVI:
         state = self.init_state(*args, **kwargs)
         nm = self.__class__.__name__
         for i in range(state.nit, self.n_total_iterations):
-            logger.info(f"{nm}: Starting {i+1:04d}")
+            logger.info(f"{nm}: Starting {i + 1:04d}")
             samples, state = self.update(samples, state)
             msg = self.get_status_message(
                 samples, state, map=self.residual_map, name=nm
@@ -894,7 +895,7 @@ def optimize_kl(
 
     nm = "OPTIMIZE_KL"
     for i in range(opt_vi_st.nit, opt_vi.n_total_iterations):
-        logger.info(f"{nm}: Starting {i+1:04d}")
+        logger.info(f"{nm}: Starting {i + 1:04d}")
         samples, opt_vi_st = opt_vi.update(samples, opt_vi_st)
         msg = opt_vi.get_status_message(samples, opt_vi_st, name=nm)
         logger.info(msg)