Skip to content
Snippets Groups Projects

Multi gpu

Merged Jakob Roth requested to merge multi_gpu into NIFTy_8
3 files
+ 14
11
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 6
6
@@ -10,7 +10,10 @@
# %%
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
os.environ["XLA_FLAGS"] = (
"--xla_force_host_platform_device_count=8" # Use 8 CPU devices
)
import jax
import matplotlib.pyplot as plt
@@ -96,8 +99,6 @@ key, k_i, k_o = random.split(key, 3)
# `resamples=False`, as additional samples that did not exist before must be drawn.
init_pos = jft.Vector(lh.init(k_i))
jax.debug.visualize_array_sharding(init_pos)
# raise
samples, state = jft.optimize_kl(
lh,
@@ -134,15 +135,14 @@ samples, state = jft.optimize_kl(
sample_mode="linear_resample",
odir="results_intro",
resume=False,
kl_map='smap',
residual_map='smap',
kl_map="smap",
residual_map="smap",
map_over_devices=jax.devices(),
# NOTE: The IWP model of the correlated field currently causes performance
# issues when used with shard_map device parallelization for the CG. To
# mitigate, either disable flexibility and asperity or use pmap to
# parallelize over the sampling CG.
use_pmap=True,
)
# %%
Loading