
Multi gpu

Merged Jakob Roth requested to merge multi_gpu into NIFTy_8

This PR introduces the feature of distributing the samples of a variational inference reconstruction with optimize_kl over multiple JAX devices. Specifically, by passing a list of JAX devices and a device mapping operation, samples can be drawn in parallel on multiple devices, and the KL minimization can be distributed as well.

Parallelizing over multiple devices primarily brings the advantage that, in total, more GPU memory is available, so larger reconstructions become possible. Additionally, the inference can become computationally faster by distributing the compute over multiple GPUs. When mapping over devices, the same random numbers are used as when running on a single device; thus, up to numerical effects, the same samples are drawn.

JAX offers several ways to parallelize across multiple GPUs, and development does not seem to have converged. The current options to map over devices are pmap and shard_map for manual parallelization, as well as automatic parallelization via jit compilation. Currently, it is not clear (to me) which approach is best for optimize_kl; this may depend on the model and on future JAX versions. Therefore, for now, I leave the choice of how to map over devices to the user, just as we already leave the choice of how to map over samples on a single device.
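For illustration only, here is a minimal sketch of the two manual options applied to a toy per-sample function (the function, names, and shapes are made up for this example and are not the optimize_kl interface):

```python
# Toy sketch: mapping a per-sample computation over devices with pmap and shard_map.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def per_sample_loss(sample):
    # stand-in for drawing/evaluating a single VI sample
    return jnp.sum(sample ** 2)

n_dev = jax.device_count()
samples = jnp.ones((n_dev, 1000))  # one sample per device, for simplicity

# pmap: one program replica per device; the mapped axis is the device axis
loss_pmap = jax.pmap(per_sample_loss)(samples)

# shard_map: explicitly shard the sample axis over a 1-D device mesh
mesh = Mesh(jax.devices(), axis_names=("samples",))
loss_shmap = shard_map(
    jax.vmap(per_sample_loss),
    mesh=mesh,
    in_specs=P("samples"),
    out_specs=P("samples"),
)(samples)
```

pmap ties the mapped axis one-to-one to the devices, while shard_map only requires the sharded axis to be divisible by the mesh size; the jit-based automatic variant is sketched further down in the discussion.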

Edited by Jakob Roth

Merge request reports

Pipeline #241455 passed

Pipeline passed for c21afaec on multi_gpu

Test coverage 80.00% (0.00%) from 1 job

Merged by Gordian Edenhofer 1 month ago (Mar 7, 2025 9:30pm UTC)


Pipeline #242880 failed

Pipeline failed for 1743c4a8 on NIFTy_8

Test coverage 79.00% (0.00%) from 1 job

Activity

  • The demo works great for me and the code looks great! :tada: I think it needs to be rebased on every_jaxlib_version and needs a version constraint to a very recent JAX release so that jax.make_mesh is available. (As always, I have some other nitpicky comments that I'm happy to share once you think the PR is ready for a merge :see_no_evil: :see_no_evil:)

    If you run benchmarks, I would be very curious to know how far jax.jit(..., in_shardings=..., out_shardings=...) gets us. I think the current implementation is better as it explicitly tells JAX to use the best axis for parallelization but I would be very curious to see how good JAX's automagic discovery is.
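    Only as a rough sketch of that automatic variant (toy function and assumed names, not NIFTy code), using jax.make_mesh as mentioned above:

    ```python
    # Toy sketch: letting jit partition the computation given explicit shardings.
    import jax
    import jax.numpy as jnp
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((jax.device_count(),), ("samples",))
    in_shard = NamedSharding(mesh, P("samples"))   # shard the sample axis
    out_shard = NamedSharding(mesh, P())           # replicate the scalar result

    def mean_loss(samples):
        # stand-in for a sample-averaged KL value
        return jnp.mean(jnp.sum(samples ** 2, axis=-1))

    samples = jnp.ones((jax.device_count(), 1000))
    loss = jax.jit(mean_loss, in_shardings=in_shard, out_shardings=out_shard)(samples)
    ```

    Here the compiler decides where to insert the cross-device reductions, which is exactly the "automagic discovery" whose quality would be interesting to benchmark.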

  • Gordian Edenhofer
  • Jakob Roth added 1 commit


    • 62816eb8 - optimize_kl: avoid using JAX>=0.5.0 functions

    Compare with previous version

  • Jakob Roth added 1 commit


    • e62a8191 - optimze_kl: simplify shared_map in/out_specs

    Compare with previous version

  • Author Maintainer

    In our current implementation of the newton_cg for minimizing the KL, we solve one CG (per Newton step) independently of how many samples we use. We do this by averaging the gradients and the metric of the individual samples in each step of this CG. I believe this approach is beneficial if all samples are on the same device and the reduce (i.e. averaging) operations over the sample axis are quick.

    I am not sure this is a good approach if the samples are on separate devices, as it triggers a lot of communication between the devices. In fact, after each CG step, the results of the metric application need to be cross-communicated between the devices to compute the mean value (sketched below). I believe if samples are distributed over multiple devices, it would be better if each device solves the CG for its local samples, and then the results are averaged in the end.

    As the queues on the GPU cluster are currently very long, I did not have the opportunity to test how much this communication overhead actually slows things down. But I could imagine that it is a significant effect, making it worth implementing a NewtonCG as proposed.
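    Just to make the communication pattern concrete, here is a hypothetical toy sketch (not our actual metric/CG code) of the per-iteration all-reduce that the sample-averaged metric implies when each sample lives on its own device:

    ```python
    # Toy sketch: averaging per-sample metric applications across devices.
    import jax
    import jax.numpy as jnp

    def metric_apply(sample, direction):
        # stand-in for the metric of a single sample acting on a search direction
        return (1.0 + jnp.sum(sample ** 2)) * direction

    n_dev = jax.device_count()
    samples = jnp.ones((n_dev, 1000))   # one sample per device
    direction = jnp.ones((50,))         # current CG search direction

    averaged_metric = jax.pmap(
        lambda s, d: jax.lax.pmean(metric_apply(s, d), axis_name="samples"),
        axis_name="samples",
        in_axes=(0, None),
    )
    # The pmean is a cross-device all-reduce; with the current NewtonCG it would
    # be triggered once per CG iteration.
    avg_mv = averaged_metric(samples, direction)
    ```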

    I think this was never actually done, but in the old NIFTy there should be the same effect when running with MPI tasks distributed over multiple nodes. @mtr, @pfrank, or @g-philipp, have you ever thought about implementing something similar in the old NIFTy, and do you have insights to share with me?

  • I believe if samples are distributed over multiple devices, it would be better if each device solves the CG for its local samples, and then the results are averaged in the end.

    It will most likely reduce communication, but I could imagine that this breaks some fundamental assumptions of the algorithm. To me it sounds pretty dangerous, and it will also make the computation results depend on how many GPUs were used.

    It's true that communication overheads are even worse when computing on GPUs instead of CPUs, but I wouldn't trust this approach without very thorough testing and analytical double-checking.

    There may be the additional complication that CG requires a different number of iterations on different devices, which replaces the communication overhead with load imbalance. Or, if we are especially unlucky, the CG iteration count may generally go up with a decreasing number of samples...

  • Author Maintainer

    Ah wait, I think this is not even mathematically possible. Even aside from numerical effects, this wouldn't give the same results. Sorry for the noise!

    • Thanks for the great work! Looks really cool. Indeed, as you already pointed out, distributing this way would alter the algorithm in an untested way.

      I believe there is space to explore altering our VI algorithms by modifying what we use as the proximal curvature. If somebody wants to look at it, I'd suggest doing this in a separate feature, though. Only for completeness: the presumably most direct modification to test would be to replace the average metric with the metric evaluated at the latent mean and compare full runs against each other. (And potentially try other things @gedenhof experimented with, such as curvature + trust-region-ncg instead of metric + nncg.) As Martin already pointed out, the overall performance gain/loss will likely be problem dependent.

    • Jakob and I played around with trust region NCG but I think we did not find any significant differences and thus opted not to pursue it further.

      In JAX, I have the feeling that using the Hessian is better than using the metric. The NCG minimizer in its current form should already handle the case of negative eigenvalues correctly (basically detect them and then stop). In essence, everything is already set up. We should probably do a representative test in the group first, though, if we want to switch.

    • Author Maintainer

      I agree that trust region NCG will probably not make much of a difference.

      Using the metric at the latent mean position instead of averaging the metric might make a difference, though, as it would require much less communication between GPUs.
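      Just to spell out the two curvature choices (notation assumed for this sketch: samples $\xi_i$, $i = 1, \dots, N$, latent mean $\bar\xi$, metric $M$):

      $$
      \bar{M} \,=\, \frac{1}{N} \sum_{i=1}^{N} M(\xi_i)
      \qquad \text{vs.} \qquad
      M(\bar{\xi}).
      $$

      Every application of $\bar{M}$ inside CG requires reducing the $N$ per-sample applications across devices, whereas $M(\bar{\xi})$ can be applied without any cross-sample communication.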

  • Author Maintainer

    Yes, leaving the minimizers' math unchanged in this PR sounds like a good idea to me. When I know how much communication overhead this creates, I will let you know.

    In a separate PR, we could eventually test minimizers that parallelize better over multiple devices.

  • Jakob Roth added 1 commit


    • 7cf18580 - test_optimize_kl: test single VS multi device

    Compare with previous version

  • Jakob Roth added 2 commits


    • 2a756f39 - optimize_kl: add multi gpu support for geoVI
    • a5a01f08 - test_optimize_kl: test shard_map for all smpl modes

    Compare with previous version

  • Jakob Roth added 3 commits


    • c18a6b62 - optimize_kl: special case n_smpl/2 = #devices
    • 0923b083 - test_optimize_kl: test for different n_sample
    • d95c7dbc - optimize_kl: add docsting for "map_over_devices"

    Compare with previous version

  • Jakob Roth added 3 commits


    • 22eef3eb - optimize_kl: optionally pmap over samples
    • 8b7e4aa9 - a_multi_gpu: showcase use of pmap
    • b483d865 - test_optimize_kl: test pmap for mapping over CG

    Compare with previous version

  • Author Maintainer

    I made initial performance tests on actual GPUs. I noticed that for some models, mapping with shard_map over draw_linear_residual (here) leads to exceptionally long compile times. Specifically, when I increase the dims in the demo to a much larger number of pixels, I get extremely long compile times and warnings about constant folding in the integrated Wiener process. Interestingly, this doesn't happen when running on a single GPU. It also does not occur in the KL minimization, even with shard_map over multiple GPUs.

    I believe these excessive compile times for some models in the shard_map sampling could have something to do with disabling some checks (here), as JAX currently does not fully support shard_map for our CG implementation. This would explain why long compile times are only observed when mapping the sampling. For other models, for example when disabling flexibility and asperity in the demo, this problem does not occur.

    As a workaround, I have now implemented the option to also pmap over the CG (sketched below). This seems to work fine but comes with the restriction that the number of GPUs has to be equal to the number of samples.
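    For illustration, a toy sketch of this workaround (assumed function name, not the actual draw_linear_residual/CG code): each device runs its own per-sample solve under pmap, which is why the number of samples must equal the number of devices:

    ```python
    # Toy sketch: one per-sample solve per device via pmap.
    import jax
    import jax.numpy as jnp

    def per_sample_solve(sample):
        # stand-in for a per-sample CG solve, e.g. in draw_linear_residual
        return sample / (1.0 + jnp.sum(sample ** 2))

    n_samples = jax.local_device_count()  # the workaround requires this equality
    samples = jnp.ones((n_samples, 1000))
    residuals = jax.pmap(per_sample_solve)(samples)  # one solve per device
    ```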

    In the coming days, I will do some final performance checks. If I don't run into more issues, this PR is mostly ready from my side.

  • Jakob Roth added 1 commit


    Compare with previous version

  • Jakob Roth added 1 commit


    Compare with previous version

  • Jakob Roth added 73 commits


    • f85c0598...0e7ee3a8 - 55 commits from branch NIFTy_8
    • 0e7ee3a8...3dc3075b - 8 earlier commits
    • 6b232c3c - test_optimize_kl: test for different n_sample
    • 7859a61a - optimize_kl: add docsting for "map_over_devices"
    • e1947469 - optimize_kl: optionally pmap over samples
    • 9842dc4c - a_multi_gpu: showcase use of pmap
    • 7ee1dcec - test_optimize_kl: test pmap for mapping over CG
    • 71bcbb3a - formatting
    • 64c98875 - optimize_kl: refactor switch between device maps
    • 7253ab3d - optimize_kl: allow device mapping with jit
    • a274223d - optimize_kl: allow different device maps for KL
    • 83c46efb - a_multi_gpu: update demo

    Compare with previous version

  • Jakob Roth added 1 commit


    • d54a3c68 - optimize_kl: update docstring for multi gpu

    Compare with previous version

  • Author Maintainer

    I have run some benchmarks on a node with 4 GPUs. Basically, I ran versions of the demo script with many more pixels and the non-static newton-cg. Using the static version led to significantly higher memory consumption, as already reported in #417.

    • Memory: The memory consumption when mapping over multiple devices seems to scale as expected. Specifically, with n_samples=4, I was able to run approximately 3 times larger models when using 4 instead of 1 GPU. (It is very much expected that you can't run 4 times larger models, as you also need to store the data and the VI expansion point.)
    • Speed: Running on multiple GPUs can, in principle, give speedups. Nevertheless, the speedup is less than linear. For smaller models, distributing over GPUs can even lead to slowdowns; I believe this comes from the communication overhead of the KL minimization. I also observed that when making the model bigger, the computation time often increases by less than the same factor, presumably because small models do not fully utilize the GPU.