Multi-GPU
This PR introduces the feature of distributing the samples of a variational inference reconstruction with optimize_kl over multiple JAX devices. Specifically, by passing a list of JAX devices and a device-mapping operation, samples can be drawn in parallel on multiple devices, and the KL minimization can also be distributed.
Parallelizing over multiple devices primarily brings the advantage that more GPU memory is available in total, and therefore larger reconstructions can be done. Additionally, the inference can become computationally faster by distributing the compute over multiple GPUs. When mapping over devices, the same random numbers are used as when running on a single device; thus, up to numerical effects, the same samples are drawn.
JAX offers several ways to parallelize across multiple GPUs, and development does not seem to have converged yet. The current options for mapping over devices are pmap and shard_map for manual parallelization, as well as automatic parallelization via jit compilation. It is currently not clear (to me) which approach is best for optimize_kl, and this might depend on the model and on future JAX versions. Therefore, for now, I leave the choice of how to map over devices to the user, similarly to how we already leave the choice of how to map over the samples on a single device.
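To make the two manual options concrete, here is a minimal, self-contained sketch on a toy per-sample function. The function and axis names are purely illustrative and are not part of the optimize_kl API; jax.make_mesh requires a recent JAX version (see the review comment below).

```python
# Illustrative sketch only -- a toy per-sample function, not the optimize_kl API.
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

def per_sample(x):  # stand-in for drawing/processing one sample
    return jnp.sin(x) * 2.0

n_dev = jax.device_count()
xs = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)

# Manual option 1: pmap assigns one slice of the leading axis to each device.
ys_pmap = jax.pmap(per_sample)(xs)

# Manual option 2: shard_map with an explicit device mesh and partition specs.
mesh = jax.make_mesh((n_dev,), ("samples",))  # needs a recent JAX version
ys_shard = shard_map(
    jax.vmap(per_sample), mesh, in_specs=P("samples"), out_specs=P("samples")
)(xs)
```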
Activity
The demo works great for me and the code looks great!
I think it needs to be rebased on every_jaxlib_version and needs a JAX version constraint to a very recent JAX version so as to make jax.make_mesh available. (As always, I have some other nitpicky comments that I'm happy to share once you think the PR is ready for a merge.)
If you run benchmarks, I would be very curious to know how far jax.jit(..., in_shardings=..., out_shardings=...) gets us. I think the current implementation is better, as it explicitly tells JAX to use the best axis for parallelization, but I would be very curious to see how good JAX's automagic discovery is.
- Resolved by Gordian Edenhofer
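For comparison, a hedged sketch of what that automatic variant could look like on a toy function (illustrative names, not the code in this PR): the inputs and outputs are constrained to a NamedSharding over a sample axis, and XLA's partitioner decides how to distribute the computation in between.

```python
# Illustrative sketch only -- not the implementation in this PR.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("samples",))
shard = NamedSharding(mesh, P("samples"))

def per_sample(x):  # toy stand-in for the per-sample computation
    return jnp.sin(x) * 2.0

f = jax.jit(per_sample, in_shardings=shard, out_shardings=shard)
xs = jax.device_put(
    jnp.arange(jax.device_count() * 4.0).reshape(jax.device_count(), 4), shard
)
ys = f(xs)  # XLA's partitioner chooses how the intermediates are sharded
```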
added 1 commit
- 62816eb8 - optimize_kl: avoid using JAX>=0.5.0 functions
added 1 commit
- e62a8191 - optimze_kl: simplify shared_map in/out_specs
In our current implementation of the newton_cg for minimizing the KL, we solve one cg (per Newton step) independent of how many samples we use. We do this by averaging over the gradients and metrics of the individual samples in each step of this CG. I believe this approach is beneficial if all samples are on the same device and the reduce (i.e. averaging) operations over the sample axis are quick.
I am not sure whether this is a good approach if the samples are on separate devices, as it triggers a lot of communication between the devices. In fact, after each CG step, the results of the metric application need to be cross-communicated between the devices to compute the mean value. I believe if samples are distributed over multiple devices, it would be better if each device solves the CG for its local samples, and then the results are averaged in the end.
As the queues on the GPU cluster are currently very long, I did not have the opportunity to test how much this communication overhead actually slows things down. But I could imagine that it is a significant effect, making it worth implementing a NewtonCG as proposed.
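A toy sketch of the two variants, with a made-up diagonal "metric" metric_at rather than the NIFTy code: the first variant averages the metric application over the sample axis inside every CG iteration, which turns into a cross-device reduction when samples are sharded; the second solves one CG per sample and averages only once at the end (and, as pointed out below, is not mathematically equivalent).

```python
# Toy illustration only -- metric_at is a made-up per-sample "metric".
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def metric_at(sample, x):
    return (1.0 + 0.1 * sample**2) * x  # stand-in for the sample-dependent metric

samples = jnp.arange(4.0)  # one entry per sample (one per device when distributed)
b = jnp.ones(8)

# Variant 1 (current newton_cg): a single CG on the sample-averaged metric;
# the mean over the sample axis is taken in every CG iteration.
avg_matvec = lambda x: jnp.mean(
    jax.vmap(metric_at, in_axes=(0, None))(samples, x), axis=0
)
x_joint, _ = cg(avg_matvec, b)

# Variant 2 (proposed): solve one CG per sample independently and
# average the solutions once at the end.
solve_one = lambda s: cg(lambda x: metric_at(s, x), b)[0]
x_local = jnp.mean(jax.vmap(solve_one)(samples), axis=0)
```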
I think this was never actually done, but in the old NIFTy there should be the same effect when running with MPI tasks distributed over multiple nodes. @mtr, @pfrank, or @g-philipp, have you ever thought about implementing something similar in the old NIFTy and do you have insights to share with me?
I believe if samples are distributed over multiple devices, it would be better if each device solves the CG for its local samples, and then the results are averaged in the end.
It will most likely reduce communication, but I could imagine that this breaks some fundamental assumptions of the algorithm. To me it sounds pretty dangerous, and it will also make the computation results depend on how many GPUs were used.
It's true that communication overheads are even worse when computing on GPUs instead of CPUs, but I wouldn't trust this approach without very thorough testing and analytical double-checking.
There may be the additional complication that CG requires a different number of iterations on different devices, which replaces the communication overhead with load imbalance. Or, if we are especially unlucky, that CG iteration count generally goes up with decreasing number of samples...
Thanks for the great work! Looks really cool. Indeed, as you pointed out already, distributing this way would alter the algorithm in an untested way.
I believe there is space to explore altering our VI algorithms by modifying what we use as the proximal curvature. If somebody wants to look at it, I'd suggest doing that in a separate feature though. Only for completeness: the presumably most direct modification to test would be to replace the average metric with the metric evaluated at the latent mean and compare full runs against each other. (And potentially try other things @gedenhof experimented with, such as curvature + trust-region-ncg instead of metric + nncg.) As Martin pointed out already, the overall performance gain/loss will likely be problem dependent.
Jakob and I played around with trust-region NCG, but I think we did not find any significant differences and thus opted not to pursue it further.
In JAX, I have the feeling that using the Hessian is better than using the metric. The NCG minimizer in its current form should already handle the case of negative eigenvalues correctly (basically, detect them and then stop). In essence, everything is already set up. We should probably first do a representative test in the group, though, if we want to switch.
added 1 commit
- 7cf18580 - test_optimize_kl: test single VS multi device
I made initial performance tests on actual GPUs. In doing so, I noticed that for some models, mapping with shard_map over the draw_linear_residual (here) leads to exceptionally long compile times. Specifically, when I increase the dims in the demo to a much larger number of pixels, I get extremely long compile times and warnings about constant folding in the integrated Wiener process. Interestingly, this doesn't happen when running on a single GPU. It also does not occur in the KL minimization, even with shard_map over multiple GPUs.
I believe these excessive compile times for some models in the shard_map sampling could have something to do with disabling some checks (here), as JAX currently does not fully support shard_map for our CG implementation. This would explain why long compile times are only observed when mapping the sampling. For other models, for example when disabling flexibility and asperity in the demo, this problem does not occur.
As a workaround, I have now implemented the possibility to also pmap over the CG. This seems to work fine but comes with the restriction that the number of GPUs has to be equal to the number of samples.
In the coming days, I will do some final performance checks. If I don't run into more issues, this PR is mostly ready from my side.
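For reference, a minimal illustration (toy per-sample function, hypothetical names) of where that restriction comes from: pmap assigns exactly one slice of the leading axis to each participating device, so mapping the CG this way pins one sample to one GPU.

```python
# Toy illustration of the pmap one-sample-per-device restriction.
import jax

def sample_residual(key):  # hypothetical stand-in for a per-sample computation
    return jax.random.normal(key, (4,))

n_dev = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)  # leading axis == #devices
residuals = jax.pmap(sample_residual)(keys)  # one sample per GPU
```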
added 73 commits
- f85c0598...0e7ee3a8 - 55 commits from branch NIFTy_8
- 0e7ee3a8...3dc3075b - 8 earlier commits
- 6b232c3c - test_optimize_kl: test for different n_sample
- 7859a61a - optimize_kl: add docsting for "map_over_devices"
- e1947469 - optimize_kl: optionally pmap over samples
- 9842dc4c - a_multi_gpu: showcase use of pmap
- 7ee1dcec - test_optimize_kl: test pmap for mapping over CG
- 71bcbb3a - formatting
- 64c98875 - optimize_kl: refactor switch between device maps
- 7253ab3d - optimize_kl: allow device mapping with jit
- a274223d - optimize_kl: allow different device maps for KL
- 83c46efb - a_multi_gpu: update demo
I have run some benchmarks on a node with 4 GPUs. Basically, I ran versions of the demo script with many more pixels and the non-static newton-cg. Using the static version led to significantly higher memory consumption, as already reported in #417.
- Memory: The memory consumption when mapping over multiple devices seems to scale as expected. Specifically, with n_samples=4, I was able to run approximately 3 times larger models when using 4 GPUs instead of 1. (It is very much expected that you can't run 4 times larger models, as you also need to store the data and the VI expansion point.)
- Speed: Running on multiple GPUs can, in principle, give speedups. Nevertheless, the speedup is less than linear. For smaller models, distributing over GPUs can even lead to slowdowns, which I believe comes from the communication overhead of the KL minimization. I also observed that when making the model bigger, the computation time often increases by a smaller factor; I believe this is because small models do not fully utilize the GPU.