Skip to content
Snippets Groups Projects
Commit db221234 authored by Jakob Roth
Browse files

optimize_kl: map_over_devices -> devices

parent e5d865bd
No related branches found
No related tags found
1 merge request: !993 "Multi gpu"
Pipeline #241421 failed
...@@ -269,7 +269,7 @@ class OptimizeVI: ...@@ -269,7 +269,7 @@ class OptimizeVI:
residual_map="lmap", residual_map="lmap",
kl_reduce=_reduce, kl_reduce=_reduce,
mirror_samples=True, mirror_samples=True,
map_over_devices: Optional[list] = None, devices: Optional[list] = None,
kl_device_map="shard_map", kl_device_map="shard_map",
residual_device_map="pmap", residual_device_map="pmap",
_kl_value_and_grad: Optional[Callable] = None, _kl_value_and_grad: Optional[Callable] = None,
...@@ -300,7 +300,7 @@ class OptimizeVI: ...@@ -300,7 +300,7 @@ class OptimizeVI:
Reduce function used for the KL minimization. Reduce function used for the KL minimization.
mirror_samples: bool mirror_samples: bool
Whether to mirror the samples or not. Whether to mirror the samples or not.
map_over_devices : list of devices or None devices : list of devices or None
Devices over which the samples are mapped. If `None` only the Devices over which the samples are mapped. If `None` only the
default device is used. To use all detected devices pass default device is used. To use all detected devices pass
jax.devices(). Generally the samples needs to be evenly jax.devices(). Generally the samples needs to be evenly
...@@ -308,17 +308,17 @@ class OptimizeVI: ...@@ -308,17 +308,17 @@ class OptimizeVI:
the `kl_device_map` and `residual_device_map` arguments. the `kl_device_map` and `residual_device_map` arguments.
kl_device_map : str kl_device_map : str
Map function used for mapping KL minimization over the devices Map function used for mapping KL minimization over the devices
listed in `map_over_devices`. `kl_device_map` can be 'shard_map', listed in `devices`. `kl_device_map` can be 'shard_map', 'pmap', or
'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to 'jit'. If set to 'pmap', 2*n_samples need to be equal to the number
the number of devices. For all other maps the samples needs to be of devices. For all other maps the samples needs to be equally
equally distributable over the devices. distributable over the devices.
residual_device_map : str residual_device_map : str
Map function used for mapping sampling over the devices listed in Map function used for mapping sampling over the devices listed in
`map_over_devices`. `residual_device_map` can be 'shard_map', `devices`. `residual_device_map` can be 'shard_map', 'pmap', or
'pmap', or 'jit'. If set to 'pmap', 2*n_samples need to be equal to 'jit'. If set to 'pmap', 2*n_samples need to be equal to the number
the number of devices. If only linear samples are drawn ,'pmap' also of devices. If only linear samples are drawn ,'pmap' also works if
works if n_samples equals the number of devices. For the other maps n_samples equals the number of devices. For the other maps it is
it is sufficient if 2*n_samples equals the number of devices, or sufficient if 2*n_samples equals the number of devices, or
n_samples can be evenly divided by the number of devices. n_samples can be evenly divided by the number of devices.
Notes Notes
...@@ -340,8 +340,8 @@ class OptimizeVI: ...@@ -340,8 +340,8 @@ class OptimizeVI:
residual_map = get_map(residual_map) residual_map = get_map(residual_map)
self.mesh = None self.mesh = None
self.pspec = None self.pspec = None
if (not map_over_devices is None) and len(map_over_devices) > 1: if (not devices is None) and len(devices) > 1:
self.mesh = Mesh(map_over_devices, ("x",)) self.mesh = Mesh(devices, ("x",))
self.pspec = Pspec("x") self.pspec = Pspec("x")
self.residual_device_map = residual_device_map self.residual_device_map = residual_device_map
...@@ -795,7 +795,7 @@ def optimize_kl( ...@@ -795,7 +795,7 @@ def optimize_kl(
resume: Union[str, bool] = False, resume: Union[str, bool] = False,
callback: Optional[Callable[[Samples, OptimizeVIState], None]] = None, callback: Optional[Callable[[Samples, OptimizeVIState], None]] = None,
odir: Optional[str] = None, odir: Optional[str] = None,
map_over_devices: Optional[list] = None, devices: Optional[list] = None,
kl_device_map="shard_map", kl_device_map="shard_map",
residual_device_map="pmap", residual_device_map="pmap",
_optimize_vi=None, _optimize_vi=None,
...@@ -835,7 +835,7 @@ def optimize_kl( ...@@ -835,7 +835,7 @@ def optimize_kl(
residual_map=residual_map, residual_map=residual_map,
kl_reduce=kl_reduce, kl_reduce=kl_reduce,
mirror_samples=mirror_samples, mirror_samples=mirror_samples,
map_over_devices=map_over_devices, devices=devices,
kl_device_map=kl_device_map, kl_device_map=kl_device_map,
residual_device_map=residual_device_map, residual_device_map=residual_device_map,
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment