Skip to content
Snippets Groups Projects

Multi gpu

Merged Jakob Roth requested to merge multi_gpu into NIFTy_8
1 file
+ 5
0
Compare changes
  • Side-by-side
  • Inline
+ 5
0
@@ -137,6 +137,11 @@ samples, state = jft.optimize_kl(
kl_map='smap',
residual_map='smap',
map_over_devices=jax.devices(),
# NOTE: The IWP model of the correlated field currently creates performance
# issues when used with shared_map device parallelization for the CG. For
# mitigation either disable flexibility and asperity or use pmap to
# parallelize over the sampling CG.
use_pmap=True,
)
Loading