diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py index 9d904a43ea77fb32aab920cae79979222dfb0873..24701163e9640bcb66f39a7952fd378a802bf283 100644 --- a/src/re/optimize_kl.py +++ b/src/re/optimize_kl.py @@ -90,8 +90,7 @@ def _kl_vg( *, map=jax.vmap, reduce=_reduce, - mesh=None, - pspec=None, + named_sharding=None, kl_device_map="shard_map", ): assert isinstance(primals_samples, Samples) @@ -101,26 +100,27 @@ def _kl_vg( if len(primals_samples) == 0: return jax.value_and_grad(ham)(primals) - if mesh is None: + if named_sharding is None: vvg = map(jax.value_and_grad(ham)) else: if kl_device_map == "shard_map": vvg = map(jax.value_and_grad(ham)) - spec_tree = tree_map(lambda x: pspec, primals) - out_spec = (pspec, spec_tree) + spec_tree = tree_map(lambda x: named_sharding.spec, primals) + out_spec = (named_sharding.spec, spec_tree) in_spec = (spec_tree,) - vvg = shard_map(vvg, mesh=mesh, in_specs=in_spec, out_specs=out_spec) + vvg = shard_map( + vvg, mesh=named_sharding.mesh, in_specs=in_spec, out_specs=out_spec + ) elif kl_device_map == "jit": vvg = map(jax.value_and_grad(ham)) - sharding = NamedSharding(mesh, pspec) - sharding_tree = tree_map(lambda x: sharding, primals) - out_sharding = (sharding, sharding_tree) + sharding_tree = tree_map(lambda x: named_sharding, primals) + out_sharding = (named_sharding, sharding_tree) in_sharding = (sharding_tree,) vvg = jax.jit(vvg, in_shardings=in_sharding, out_shardings=out_sharding) elif kl_device_map == "pmap": vvg = jax.pmap(jax.value_and_grad(ham)) else: - ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}" + ve = f"`kl_device_map` need to be `pmap`, `shard_map`, or `jit`, not {kl_device_map}" raise ValueError(ve) s = vvg(primals_samples.at(primals).samples) @@ -135,8 +135,7 @@ def _kl_met( *, map=jax.vmap, reduce=_reduce, - mesh=None, - pspec=None, + named_sharding=None, kl_device_map="shard_map", ): assert isinstance(primals_samples, Samples) @@ -147,31 +146,30 @@ def _kl_met( return ham.metric(primals, tangents) met = Partial(ham.metric, tangents=tangents) - if mesh is None: + if named_sharding is None: vmet = map(met) else: if kl_device_map == "shard_map": vmet = map(met) - spec_tree = tree_map(lambda x: pspec, primals) + spec_tree = tree_map(lambda x: named_sharding.spec, primals) out_spec = spec_tree in_spec = (spec_tree,) vmet = shard_map( vmet, - mesh=mesh, + mesh=named_sharding.mesh, in_specs=in_spec, out_specs=out_spec, ) elif kl_device_map == "jit": vmet = map(met) - sharding = NamedSharding(mesh, pspec) - sharding_tree = tree_map(lambda x: sharding, primals) + sharding_tree = tree_map(lambda x: named_sharding, primals) out_sharding = sharding_tree in_sharding = (sharding_tree,) vmet = jax.jit(vmet, in_shardings=in_sharding, out_shardings=out_sharding) elif kl_device_map == "pmap": vmet = jax.pmap(met) else: - ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}" + ve = f"`kl_device_map` need to be `pmap`, `shard_map`, or `jit`, not {kl_device_map}" raise ValueError(ve) s = vmet(primals_samples.at(primals).samples) return reduce(s) @@ -338,11 +336,11 @@ class OptimizeVI: kl_jit = _parse_jit(kl_jit) residual_jit = _parse_jit(residual_jit) residual_map = get_map(residual_map) - self.mesh = None - self.pspec = None + self.named_sharding = None if (not devices is None) and len(devices) > 1: - self.mesh = Mesh(devices, ("x",)) - self.pspec = Pspec("x") + mesh = Mesh(devices, ("x",)) + pspec = Pspec("x") + self.named_sharding = NamedSharding(mesh, pspec) self.residual_device_map = residual_device_map if mirror_samples is False: @@ -352,26 +350,34 @@ class OptimizeVI: _kl_value_and_grad = partial( kl_jit( _kl_vg, - static_argnames=("map", "reduce", "mesh", "pspec", "kl_device_map"), + static_argnames=( + "map", + "reduce", + "named_sharding", + "kl_device_map", + ), ), likelihood, map=kl_map, reduce=kl_reduce, - mesh=self.mesh, - pspec=self.pspec, + named_sharding=self.named_sharding, kl_device_map=kl_device_map, ) if _kl_metric is None: _kl_metric = partial( kl_jit( _kl_met, - static_argnames=("map", "reduce", "mesh", "pspec", "kl_device_map"), + static_argnames=( + "map", + "reduce", + "named_sharding", + "kl_device_map", + ), ), likelihood, map=kl_map, reduce=kl_reduce, - mesh=self.mesh, - pspec=self.pspec, + named_sharding=self.named_sharding, kl_device_map=kl_device_map, ) if _draw_linear_residual is None: @@ -401,7 +407,7 @@ class OptimizeVI: # NOTE, use `Partial` in favor of `partial` to allow the (potentially) # re-jitting `residual_map` to trace the kwargs kwargs = hide_strings(kwargs) - if self.mesh is None: + if self.named_sharding is None: sampler = Partial(self.draw_linear_residual, **kwargs) sampler = self.residual_map(sampler, in_axes=(None, 0)) smpls, smpls_states = sampler(primals, keys) @@ -410,44 +416,48 @@ class OptimizeVI: smpls = concatenate_zip(smpls, -smpls) else: n_samples = len(keys) - if n_samples == self.mesh.size / 2: + if n_samples == self.named_sharding.mesh.size / 2: keys = jnp.repeat(keys, 2, axis=0) - keys = jax.device_put(keys, NamedSharding(self.mesh, self.pspec)) + keys = jax.device_put(keys, self.named_sharding) # zip samples such that the mirrored-counterpart always comes right # after the original sample. out_shardings is need for telling JAX # not to move all samples to a single device. - @partial(jax.jit, out_shardings=NamedSharding(self.mesh, self.pspec)) + @partial(jax.jit, out_shardings=self.named_sharding) def concatenate_zip_pmap(*arrays): return tree_map( lambda *x: jnp.stack(x, axis=1).reshape((-1,) + x[0].shape[1:]), *arrays, ) - @partial(jax.jit, out_shardings=NamedSharding(self.mesh, self.pspec)) + @partial(jax.jit, out_shardings=self.named_sharding) def _special_mirror_samples(samples): return samples.at[1::2].set(-samples[1::2]) if self.residual_device_map == "pmap": sampler = Partial(self.draw_linear_residual, **kwargs) sampler = jax.pmap(sampler, in_axes=(None, 0)) - keys = jax.device_put(keys, NamedSharding(self.mesh, self.pspec)) + keys = jax.device_put(keys, self.named_sharding) smpls, smpls_states = sampler(primals, keys) elif self.residual_device_map == "jit": sampler = Partial(self.draw_linear_residual, primals, **kwargs) sampler = self.residual_map(sampler) - sharding = NamedSharding(self.mesh, self.pspec) sampler = jax.jit( - sampler, in_shardings=sharding, out_shardings=sharding + sampler, + in_shardings=self.named_sharding, + out_shardings=self.named_sharding, ) smpls, smpls_states = sampler(keys) elif self.residual_device_map == "shard_map": sampler = Partial(self.draw_linear_residual, primals, **kwargs) - out_spec = (tree_map(lambda x: self.pspec, primals), self.pspec) + out_spec = ( + tree_map(lambda x: self.named_sharding.spec, primals), + self.named_sharding.spec, + ) sampler = shard_map( self.residual_map(sampler), - mesh=self.mesh, - in_specs=self.pspec, + mesh=self.named_sharding.mesh, + in_specs=self.named_sharding.spec, out_specs=out_spec, check_rep=False, # FIXME Maybe enable in future JAX releases ) @@ -455,7 +465,7 @@ class OptimizeVI: else: ve = f"`residual_device_map` need to be `pmap`, `shard_map`, or `jit`, not {self.residual_device_map}" raise ValueError(ve) - if n_samples == self.mesh.size / 2: + if n_samples == self.named_sharding.mesh.size / 2: smpls = tree_map(_special_mirror_samples, smpls) keys = keys[::2] # undo jnp.repeat else: @@ -472,7 +482,7 @@ class OptimizeVI: metric_sample_key = concatenate_zip(*((samples.keys,) * 2)) sgn = jnp.ones(len(samples.keys)) sgn = concatenate_zip(sgn, -sgn) - if self.mesh is None: + if self.named_sharding is None: curver = Partial(self.nonlinearly_update_residual, **kwargs) curver = self.residual_map(curver, in_axes=(None, 0, 0, 0)) smpls, smpls_states = curver( @@ -481,26 +491,31 @@ class OptimizeVI: else: curver = Partial(self.nonlinearly_update_residual, samples.pos, **kwargs) if self.residual_device_map == "shard_map": - spec_tree = tree_map(lambda x: self.pspec, samples.pos) - out_spec = (spec_tree, self.pspec) - in_spec = (spec_tree, self.pspec, self.pspec) + spec_tree = tree_map(lambda x: self.named_sharding.spec, samples.pos) + out_spec = (spec_tree, self.named_sharding.spec) + in_spec = ( + spec_tree, + self.named_sharding.spec, + self.named_sharding.spec, + ) curver = shard_map( self.residual_map(curver, in_axes=(0, 0, 0)), - mesh=self.mesh, + mesh=self.named_sharding.mesh, in_specs=in_spec, out_specs=out_spec, check_rep=False, ) elif self.residual_device_map == "jit": - sharding = NamedSharding(self.mesh, self.pspec) - sharding_tree = tree_map(lambda x: sharding, samples.pos) - out_sharding = (sharding_tree, sharding) - in_sharding = (sharding_tree, sharding, sharding) + sharding_tree = tree_map(lambda x: self.named_sharding, samples.pos) + out_sharding = (sharding_tree, self.named_sharding) + in_sharding = (sharding_tree, self.named_sharding, self.named_sharding) curver = self.residual_map(curver) curver = jax.jit( curver, in_shardings=in_sharding, out_shardings=out_sharding ) - metric_sample_key = jax.device_put(metric_sample_key, sharding) + metric_sample_key = jax.device_put( + metric_sample_key, self.named_sharding + ) elif self.residual_device_map == "pmap": curver = jax.pmap(curver, in_axes=(0, 0, 0)) else: diff --git a/test/test_re/test_optimize_kl.py b/test/test_re/test_optimize_kl.py index 7a5d1ed2700be5cf8b9ab0a8ae5fbf5a6f099669..06f57273641852082d35fcd77616c201a363a20d 100644 --- a/test/test_re/test_optimize_kl.py +++ b/test/test_re/test_optimize_kl.py @@ -379,10 +379,10 @@ def test_optimize_kl_device_consistency( kl_jit=False, residual_jit=False, ) - samples_single_device, _ = jft.optimize_kl(**opt_kl_kwargs, map_over_devices=None) + samples_single_device, _ = jft.optimize_kl(**opt_kl_kwargs, devices=None) samples_multiple_devices, _ = jft.optimize_kl( **opt_kl_kwargs, - map_over_devices=jax.devices(), + devices=jax.devices(), residual_device_map=residual_device_map, kl_device_map=kl_device_map, )