This PR adds support for distributing the samples of a variational inference reconstruction with optimize_kl across multiple JAX devices. Specifically, by passing a list of JAX devices and a device mapping operation, samples can be drawn in parallel on multiple devices, and the KL minimization can be distributed as well.
The primary advantage of parallelizing over multiple devices is that more GPU memory is available in total, so larger reconstructions become feasible. Additionally, the inference can become faster by distributing the compute over multiple GPUs. When mapping over devices, the same random numbers are used as when running on a single device; thus, up to numerical effects, the same samples are drawn.
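To illustrate this reproducibility property, here is a minimal sketch using a toy per-sample function (a stand-in for the actual sample drawing, not the optimize_kl internals): the per-sample PRNG keys are the same regardless of how the samples are mapped, so the distributed run reproduces the single-device run.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for drawing a single sample; not the actual optimize_kl internals.
def draw_sample(key):
    return jax.random.normal(key, (4,))

# One PRNG key per sample, independent of how the samples are later mapped.
keys = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

# Single-device reference: map over the samples with vmap.
reference = jax.vmap(draw_sample)(keys)

# Distributed over devices: each device receives the same per-sample key,
# so the drawn samples agree with the single-device run up to numerical effects.
distributed = jax.pmap(draw_sample)(keys)

assert jnp.allclose(reference, distributed)
```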
JAX offers several ways to parallelize across multiple GPUs, and development in this area does not seem to have converged yet. The current options for mapping over devices are pmap and shard_map for manual parallelization, as well as automatic parallelization via jit compilation. It is currently not clear (to me) which approach is best suited for optimize_kl, and this may depend on the model and on future JAX versions. Therefore, for now, I leave the choice of how to map over devices to the user, analogously to how we already leave the choice of how to map over the samples on a single device to the user.
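For reference, the three device-mapping options look roughly as follows on a toy per-sample function. This is only a sketch of the underlying JAX mechanisms, not the optimize_kl internals; the axis and function names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Toy per-sample computation standing in for e.g. drawing or updating a sample.
def per_sample(key, x):
    return x + jax.random.normal(key, x.shape)

devices = jax.devices()
n_samples = len(devices)  # one sample per device for this toy example
keys = jax.random.split(jax.random.PRNGKey(42), n_samples)
x = jnp.zeros((n_samples, 4))

# Option 1: manual parallelization with pmap (maps the leading axis over devices).
y_pmap = jax.pmap(per_sample)(keys, x)

# Option 2: manual parallelization with shard_map over an explicit device mesh;
# each device receives its shard of the sample axis and vmaps over it locally.
mesh = Mesh(devices, axis_names=("samples",))
sharded = shard_map(
    jax.vmap(per_sample),
    mesh=mesh,
    in_specs=(P("samples"), P("samples")),
    out_specs=P("samples"),
)
y_shard = jax.jit(sharded)(keys, x)

# Option 3: automatic parallelization: jit a vmapped function on inputs that
# are already sharded across the mesh and let the compiler partition the work.
sharding = NamedSharding(mesh, P("samples"))
y_auto = jax.jit(jax.vmap(per_sample))(
    jax.device_put(keys, sharding), jax.device_put(x, sharding)
)
```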