diff --git a/src/re/optimize_kl.py b/src/re/optimize_kl.py index fd6a471905aa7a560c52f34eed766a6ae27b735e..9d904a43ea77fb32aab920cae79979222dfb0873 100644 --- a/src/re/optimize_kl.py +++ b/src/re/optimize_kl.py @@ -269,7 +269,7 @@ class OptimizeVI: residual_map="lmap", kl_reduce=_reduce, mirror_samples=True, - map_over_devices: Optional[list] = None, + devices: Optional[list] = None, kl_device_map="shard_map", residual_device_map="pmap", _kl_value_and_grad: Optional[Callable] = None, @@ -300,7 +300,7 @@ class OptimizeVI: Reduce function used for the KL minimization. mirror_samples: bool Whether to mirror the samples or not. - map_over_devices : list of devices or None + devices : list of devices or None Devices over which the samples are mapped. If `None` only the default device is used. To use all detected devices pass jax.devices(). Generally the samples needs to be evenly @@ -308,17 +308,17 @@ class OptimizeVI: the `kl_device_map` and `residual_device_map` arguments. kl_device_map : str Map function used for mapping KL minimization over the devices - listed in `map_over_devices`. `kl_device_map` can be 'shard_map', - 'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to - the number of devices. For all other maps the samples needs to be - equally distributable over the devices. + listed in `devices`. `kl_device_map` can be 'shard_map', 'pmap', or + 'jit'. If set to 'pmap', 2*n_samples need to be equal to the number + of devices. For all other maps the samples needs to be equally + distributable over the devices. residual_device_map : str Map function used for mapping sampling over the devices listed in - `map_over_devices`. `residual_device_map` can be 'shard_map', - 'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to - the number of devices. If only linear samples are drawn ,'pmap' also - works if n_samples equals the number of devices. For the other maps - it is sufficient if 2*n_samples equals the number of devices, or + `devices`. `residual_device_map` can be 'shard_map', 'pmap', or + 'jit'. If set to 'pmap', 2*n_samples need to be equal to the number + of devices. If only linear samples are drawn ,'pmap' also works if + n_samples equals the number of devices. For the other maps it is + sufficient if 2*n_samples equals the number of devices, or n_samples can be evenly divided by the number of devices. Notes @@ -340,8 +340,8 @@ class OptimizeVI: residual_map = get_map(residual_map) self.mesh = None self.pspec = None - if (not map_over_devices is None) and len(map_over_devices) > 1: - self.mesh = Mesh(map_over_devices, ("x",)) + if (not devices is None) and len(devices) > 1: + self.mesh = Mesh(devices, ("x",)) self.pspec = Pspec("x") self.residual_device_map = residual_device_map @@ -795,7 +795,7 @@ def optimize_kl( resume: Union[str, bool] = False, callback: Optional[Callable[[Samples, OptimizeVIState], None]] = None, odir: Optional[str] = None, - map_over_devices: Optional[list] = None, + devices: Optional[list] = None, kl_device_map="shard_map", residual_device_map="pmap", _optimize_vi=None, @@ -835,7 +835,7 @@ def optimize_kl( residual_map=residual_map, kl_reduce=kl_reduce, mirror_samples=mirror_samples, - map_over_devices=map_over_devices, + devices=devices, kl_device_map=kl_device_map, residual_device_map=residual_device_map, )