Multi-GPU

This PR introduces the option of distributing the samples of a variational inference reconstruction with optimize_kl over multiple JAX devices. Specifically, by passing a list of JAX devices and a device-mapping operation, samples can be drawn in parallel on multiple devices, and the KL minimization can be distributed as well.
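
As a usage sketch: the keyword names devices and device_map below are illustrative placeholders and may differ from the final API of this MR, and the likelihood and initial position are assumed to be built as in any other optimize_kl run with NIFTy.re (imported as nifty8.re):

```python
import jax
import nifty8.re as jft

key = jax.random.PRNGKey(42)
# lh: a nifty8.re likelihood; pos_init: initial latent position.
# Both are constructed exactly as for a single-device reconstruction.
samples, state = jft.optimize_kl(
    lh,
    pos_init,
    key=key,
    n_total_iterations=10,
    n_samples=4,
    devices=jax.devices(),  # assumed name: the JAX devices to distribute over
    device_map=jax.pmap,    # assumed name: the device-mapping operation to use
)
```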

Parallelizing over multiple devices primarily has the advantage that more GPU memory is available in total, so larger reconstructions become feasible. Additionally, the inference can become computationally faster by distributing the compute over multiple GPUs. When mapping over devices, the same random numbers are used as when running on a single device; thus, up to numerical effects, the same samples are drawn, as the sketch below illustrates.
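
A minimal pure-JAX sketch of why this holds: the per-sample PRNG keys are derived once from the global key, and the device mapping only changes where each sample is computed, not which key it uses (draw below is a stand-in for the actual sample drawing):

```python
import jax
import jax.numpy as jnp

def draw(key):
    # Stand-in for drawing one variational sample.
    return jax.random.normal(key, (1000,))

n_devices = len(jax.devices())
n_samples = 2 * n_devices
keys = jax.random.split(jax.random.PRNGKey(0), n_samples)

# Single device: map over the sample axis with vmap.
samples_single = jax.vmap(draw)(keys)

# Multiple devices: distribute the same keys, two samples per device.
samples_multi = jax.pmap(jax.vmap(draw))(keys.reshape(n_devices, 2, -1))

# Same keys, hence the same samples up to numerical effects.
assert jnp.allclose(samples_single, samples_multi.reshape(n_samples, -1))
```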

JAX offers several ways to parallelize across multiple GPUs, and development does not seem to have converged yet. The current options for mapping over devices are pmap and shard_map for manual parallelization, as well as automatic parallelization via jit compilation. It is currently not clear (to me) which approach is best suited for optimize_kl, and the answer may depend on the model and on future JAX versions. Therefore, for now, I leave the choice of how to map over devices to the user, just as we already leave the choice of how to map over the samples on a single device.
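
For concreteness, a pure-JAX sketch of the three options on a toy per-sample function (the shard_map import location varies between JAX versions; jax.experimental.shard_map is assumed here):

```python
import numpy as np
import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def draw_sample(key):
    # Stand-in for an expensive per-sample computation.
    return jax.random.normal(key, (1000,))

devices = np.array(jax.devices())
mesh = Mesh(devices, ("samples",))
keys = jax.random.split(jax.random.PRNGKey(0), len(devices))

# Option 1: pmap, manual parallelization with one sample per device.
samples_pmap = jax.pmap(draw_sample)(keys)

# Option 2: shard_map, manual parallelization over an explicit device mesh.
draw_sharded = shard_map(
    jax.vmap(draw_sample),
    mesh=mesh,
    in_specs=P("samples"),
    out_specs=P("samples"),
)
samples_shard = draw_sharded(keys)

# Option 3: automatic parallelization, where jit partitions the
# computation based on the sharding of its inputs.
keys_sharded = jax.device_put(keys, NamedSharding(mesh, P("samples")))
samples_auto = jax.jit(jax.vmap(draw_sample))(keys_sharded)
```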
