Commit 88258c6e authored by Martin Reinecke
Browse files

Merge branch 'mpi_equality' into 'NIFTy_8'

Simplify MPI equality check

See merge request !664
parents 16827157 b8344c60
Pipeline #106060 passed with stages
in 34 minutes and 51 seconds
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import collections import collections
from functools import reduce from functools import reduce
from itertools import product from itertools import product
import pickle
import numpy as np import numpy as np
...@@ -444,51 +445,14 @@ def check_MPI_equality(obj, comm): ...@@ -444,51 +445,14 @@ def check_MPI_equality(obj, comm):
comm : MPI communicator or None comm : MPI communicator or None
If comm is None, no check will be performed If comm is None, no check will be performed
""" """
# Special cases
if comm is None: if comm is None:
return return
elif isinstance(obj, list): if not _MPI_unique(obj, comm):
_check_MPI_equality_lists(obj, comm) raise RuntimeError("MPI tasks are not in sync")
elif isinstance(obj, np.random.SeedSequence):
_check_MPI_equality_sseq(obj, comm)
# /Special cases
else:
if not _MPI_unique(obj, comm):
raise RuntimeError("MPI tasks are not in sync")
def _check_MPI_equality_lists(lst, comm):
if not isinstance(lst, list):
raise TypeError
if not _MPI_unique(len(lst), comm):
raise RuntimeError("MPI tasks are not in sync (lists have different lengths)")
is_sseq = comm.allgather(lst[0])
if is_sseq[0]:
if not all(is_sseq):
raise RuntimeError("First element in list is np.random.SeedSequence. The others (partly) not.")
for oo in lst:
check_MPI_equality(oo, comm)
return
for ii in range(len(lst)):
if not _MPI_unique(lst[ii], comm):
raise RuntimeError(f"MPI tasks are not in sync (list element #{ii} does not match)")
def _MPI_unique(obj, comm): def _MPI_unique(obj, comm):
return len(set(comm.allgather(obj))) == 1 return len(set(comm.allgather(pickle.dumps(obj)))) == 1
def _check_MPI_equality_sseq(sseq, comm):
from .random import Context, spawn_sseq, current_rng
if not isinstance(sseq, np.random.SeedSequence):
raise TypeError
with Context(spawn_sseq(1, parent=sseq)[0]):
random_number = current_rng().normal(10., 1.2, (1,))[0]
gath = comm.allgather(random_number)
if gath[1:] != gath[:-1]:
raise RuntimeError("SeedSequences are not equal")
def check_MPI_synced_random_state(comm): def check_MPI_synced_random_state(comm):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment