diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py index 24701163e9640bcb66f39a7952fd378a802bf283..25aebd326b05c63a39687f9bad9a3ffeb44536a8 100644 --- a/src/re/optimize_kl.py +++ b/src/re/optimize_kl.py @@ -15,9 +15,10 @@ import jax import numpy as np from jax import numpy as jnp from jax import random -from jax.tree_util import Partial, tree_map -from jax.sharding import PartitionSpec as Pspec, NamedSharding, Mesh from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as Pspec +from jax.tree_util import Partial, tree_map from . import optimize from .evi import ( @@ -603,7 +604,7 @@ class OptimizeVI: pl = samples.pos if constants: from .likelihood import _parse_point_estimates, partial_insert_and_remove - from .tree_math import zeros_like, Vector + from .tree_math import Vector, zeros_like insert_axes, pl, primals_frozen = _parse_point_estimates(constants, pl) unflatten = Vector if insert_axes else None @@ -777,7 +778,7 @@ class OptimizeVI: state = self.init_state(*args, **kwargs) nm = self.__class__.__name__ for i in range(state.nit, self.n_total_iterations): - logger.info(f"{nm}: Starting {i+1:04d}") + logger.info(f"{nm}: Starting {i + 1:04d}") samples, state = self.update(samples, state) msg = self.get_status_message( samples, state, map=self.residual_map, name=nm @@ -894,7 +895,7 @@ def optimize_kl( nm = "OPTIMIZE_KL" for i in range(opt_vi_st.nit, opt_vi.n_total_iterations): - logger.info(f"{nm}: Starting {i+1:04d}") + logger.info(f"{nm}: Starting {i + 1:04d}") samples, opt_vi_st = opt_vi.update(samples, opt_vi_st) msg = opt_vi.get_status_message(samples, opt_vi_st, name=nm) logger.info(msg)