Multi GPU
This PR introduces the feature of distributing the samples of a variational inference reconstruction with `optimize_kl` over multiple JAX devices. Specifically, by passing a list of JAX devices and a device mapping operation, samples can be drawn in parallel on multiple devices, and the KL minimization can be distributed as well.
Parallelizing over multiple devices primarily has the advantage that more GPU memory is available in total, so larger reconstructions become feasible. Additionally, the inference can become computationally faster by distributing the compute over multiple GPUs. When mapping over devices, the same random numbers are used as when running on a single device. Thus, up to numerical effects, the same samples are drawn.
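The following JAX-only sketch (not the `optimize_kl` internals) illustrates why the multi-device path reproduces the single-device samples: the same PRNG keys are used either way, so the draws agree up to numerical effects. The function `draw_sample` is just a stand-in for drawing one sample.

```python
import jax
import jax.numpy as jnp

def draw_sample(key):
    # Stand-in for drawing one variational sample from a given PRNG key.
    return jax.random.normal(key, (4,))

# One key per sample; here, one sample per available device.
keys = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

# Single-device reference: map over the sample axis with vmap.
samples_single = jax.vmap(draw_sample)(keys)

# Multi-device version: the leading axis of `keys` is split across devices.
samples_multi = jax.pmap(draw_sample)(keys)

# Up to numerical effects, both give the same samples.
assert jnp.allclose(samples_single, samples_multi)
```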
JAX offers several ways to parallelize across multiple GPUs, and development in this area does not seem to have converged yet. The current options for mapping over devices are `pmap` and `shard_map` for manual parallelization, as well as automatic parallelization via `jit` compilation. It is currently not clear (to me) which approach is best suited for `optimize_kl`, and this might depend on the model and on future JAX versions. Therefore, for now, I leave the choice of how to map over devices to the user, just as we already leave the choice of how to map over the samples on a single device to the user.
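As a rough illustration of what the user can choose between, here is a minimal sketch of the three mapping options applied to a toy per-sample computation; `residual` is just a stand-in for the per-sample work, not the actual code used in `optimize_kl`.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # import path of current JAX versions

def residual(x):
    # Toy stand-in for the per-sample computation.
    return jnp.sum(x**2)

n_dev = jax.device_count()
xs = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)
mesh = Mesh(jax.devices(), axis_names=("samples",))

# 1) Manual parallelization with pmap: one sample per device.
out_pmap = jax.pmap(residual)(xs)

# 2) Manual parallelization with shard_map over an explicit device mesh.
out_shmap = shard_map(
    jax.vmap(residual), mesh=mesh,
    in_specs=P("samples"), out_specs=P("samples"),
)(xs)

# 3) Automatic parallelization: jit plus an input sharding; the compiler
#    decides how to distribute the computation across the devices.
xs_sharded = jax.device_put(xs, NamedSharding(mesh, P("samples")))
out_jit = jax.jit(jax.vmap(residual))(xs_sharded)
```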