diff --git a/src/minimization/kl_energies.py b/src/minimization/kl_energies.py
index caf31d0e1e5590aea105558680b6e62cbc6298b8..0025996ef055e0473d10b396d749a3f9d3aed2df 100644
--- a/src/minimization/kl_energies.py
+++ b/src/minimization/kl_energies.py
@@ -211,6 +211,7 @@ class _MetricGaussianSampler:
def draw_samples(self, comm):
local_samples = []
+ utilities.check_MPI_synced_random_state(comm)
sseq = random.spawn_sseq(self._n)
for i in range(*_get_lo_hi(comm, self._n)):
with random.Context(sseq[i]):
@@ -315,6 +316,8 @@ class _GeoMetricSampler:
def draw_samples(self, comm):
local_samples = []
prev = None
+ utilities.check_MPI_synced_random_state(comm)
+ utilities.check_MPI_equality(self._sseq, comm)
for i in range(*_get_lo_hi(comm, self.n_eff_samples)):
with random.Context(self._sseq[i]):
neg = self._neg[i]
diff --git a/src/utilities.py b/src/utilities.py
index 7f843fce03f9655ccf45dfddd2fc90a4a272497c..465f829e70310f9e3fb10edadf4af29f5e2784c6 100644
--- a/src/utilities.py
+++ b/src/utilities.py
@@ -25,7 +25,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
- "value_reshaper", "lognormal_moments"]
+ "value_reshaper", "lognormal_moments",
+ "check_MPI_equality", "check_MPI_synced_random_state"]
def my_sum(iterable):
@@ -412,3 +413,80 @@ def myassert(val):
`__debug__` is False."""
if not val:
raise AssertionError
+
+
+def check_MPI_equality(obj, comm):
+ """Check that object is the same on all MPI tasks associated to a given
+ communicator.
+
+ Raises a RuntimeError if it differs.
+
+ Parameters
+ ----------
+ obj :
+ Any Python object that implements __eq__.
+ comm : MPI communicator or None
+ If comm is None, no check will be performed
+ """
+ # Special cases
+ if comm is None:
+ return
+ elif isinstance(obj, list):
+ _check_MPI_equality_lists(obj, comm)
+ elif isinstance(obj, np.random.SeedSequence):
+ _check_MPI_equality_sseq(obj, comm)
+ # /Special cases
+ else:
+ if not _MPI_unique(obj, comm):
+ raise RuntimeError("MPI tasks are not in sync")
+
+
+def _check_MPI_equality_lists(lst, comm):
+ if not isinstance(lst, list):
+ raise TypeError
+ if not _MPI_unique(len(lst), comm):
+ raise RuntimeError("MPI tasks are not in sync (lists have different lengths)")
+
+ is_sseq = comm.allgather(lst[0])
+ if is_sseq[0]:
+ if not all(is_sseq):
+ raise RuntimeError("First element in list is np.random.SeedSequence. The others (partly) not.")
+ for oo in lst:
+ check_MPI_equality(oo, comm)
+ return
+
+ for ii in range(len(lst)):
+ if not _MPI_unique(lst[ii], comm):
+ raise RuntimeError(f"MPI tasks are not in sync (list element #{ii} does not match)")
+
+
+def _MPI_unique(obj, comm):
+ return len(set(comm.allgather(obj))) == 1
+
+
+def _check_MPI_equality_sseq(sseq, comm):
+ from .random import Context, spawn_sseq, current_rng
+ if not isinstance(sseq, np.random.SeedSequence):
+ raise TypeError
+ with Context(spawn_sseq(1, parent=sseq)[0]):
+ random_number = current_rng().normal(10., 1.2, (1,))[0]
+ gath = comm.allgather(random_number)
+ if gath[1:] != gath[:-1]:
+ raise RuntimeError("SeedSequences are not equal")
+
+
+def check_MPI_synced_random_state(comm):
+ """Check that random state is the same on all MPI tasks associated to a
+ given communicator.
+
+ Raises a RuntimeError if it differs.
+
+ Parameters
+ ----------
+ comm : MPI communicator or None
+ If comm is None, no check will be performed
+ """
+ from .random import getState
+ if comm is None:
+ return
+ check_MPI_equality(getState(), comm)
diff --git a/test/test_mpi/test_sync.py b/test/test_mpi/test_sync.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4744990edba178d7b39d711584bcf60b36c00ab
--- /dev/null
+++ b/test/test_mpi/test_sync.py
@@ -0,0 +1,61 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2021 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+import pytest
+from mpi4py import MPI
+
+import nifty7 as ift
+
+from ..common import setup_function, teardown_function
+
+comm = MPI.COMM_WORLD
+ntask = comm.Get_size()
+rank = comm.Get_rank()
+master = (rank == 0)
+mpi = ntask > 1
+
+pmp = pytest.mark.parametrize
+pms = pytest.mark.skipif
+
+
+@pms(not mpi, reason="requires at least two mpi tasks")
+def test_MPI_equality():
+ obj = rank
+ with pytest.raises(RuntimeError):
+ ift.utilities.check_MPI_equality(obj, comm)
+
+ obj = [ii + rank for ii in range(10, 12)]
+ with pytest.raises(RuntimeError):
+ ift.utilities.check_MPI_equality(obj, comm)
+
+ sseqs = ift.random.spawn_sseq(2)
+ for obj in [12., None, (29, 30), [1, 2, 3], sseqs[0], sseqs]:
+ ift.utilities.check_MPI_equality(obj, comm)
+
+ obj = ift.random.spawn_sseq(2, parent=sseqs[comm.rank])
+ with pytest.raises(RuntimeError):
+ ift.utilities.check_MPI_equality(obj, comm)
+
+
+@pms(not mpi, reason="requires at least two mpi tasks")
+def test_MPI_synced_random_state():
+ ift.utilities.check_MPI_synced_random_state(comm)
+
+ if master:
+ ift.random.push_sseq_from_seed(123)
+ with pytest.raises(RuntimeError):
+ ift.utilities.check_MPI_synced_random_state(comm)