Commit 285d1960 authored by Philipp Arras
Browse files

KL energies: check that random state is synced across MPI tasks

parent ba7e20e2
......@@ -211,6 +211,7 @@ class _MetricGaussianSampler:
def draw_samples(self, comm):
local_samples = []
utilities.check_MPI_synced_random_state(comm)
sseq = random.spawn_sseq(self._n)
for i in range(*_get_lo_hi(comm, self._n)):
with random.Context(sseq[i]):
......@@ -315,6 +316,8 @@ class _GeoMetricSampler:
def draw_samples(self, comm):
local_samples = []
prev = None
utilities.check_MPI_synced_random_state(comm)
utilities.check_MPI_equality(self._sseq, comm)
for i in range(*_get_lo_hi(comm, self.n_eff_samples)):
with random.Context(self._sseq[i]):
neg = self._neg[i]
......
......@@ -25,7 +25,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments", "check_domain_equality"]
"value_reshaper", "lognormal_moments", "check_domain_equality",
"check_MPI_equality", "check_MPI_synced_random_state"]
def my_sum(iterable):
......@@ -428,3 +429,80 @@ def check_domain_equality(domain0, domain1):
f"ift.MultiDomain nor of ift.DomainTuple.\n{dom}")
if domain0 != domain1:
raise ValueError(f"Domain mismatch:\n{domain0}\n{domain1}")
def check_MPI_equality(obj, comm):
    """Verify that an object agrees on every task of an MPI communicator.

    Lists and `np.random.SeedSequence` instances are handed to dedicated
    checks; any other object is gathered from all tasks and compared
    directly.

    Parameters
    ----------
    obj :
        Any Python object that implements __eq__.
    comm : MPI communicator or None
        If comm is None, the function returns without checking anything.

    Raises
    ------
    RuntimeError
        If the object is not the same on all tasks.
    """
    # Without a communicator there is nothing to compare against.
    if comm is None:
        return

    # Lists and seed sequences get their own comparison routines.
    if isinstance(obj, list):
        _check_MPI_equality_lists(obj, comm)
        return
    if isinstance(obj, np.random.SeedSequence):
        _check_MPI_equality_sseq(obj, comm)
        return

    # Generic case: compare the gathered objects directly.
    if _MPI_unique(obj, comm):
        return
    raise RuntimeError("MPI tasks are not in sync")
def _check_MPI_equality_lists(lst, comm):
    """Check that a list is the same on all MPI tasks of a communicator.

    The list lengths are compared first, then the elements one by one.

    Parameters
    ----------
    lst : list
        List to compare across the tasks.
    comm : MPI communicator
        Communicator whose tasks are compared. Its `allgather` method is
        called, so it must not be None.

    Raises
    ------
    TypeError
        If `lst` is not a list.
    RuntimeError
        If the lengths or any element differ between tasks, or if the first
        element is a np.random.SeedSequence on the first task but not on all
        others.
    """
    if not isinstance(lst, list):
        raise TypeError
    if not _MPI_unique(len(lst), comm):
        raise RuntimeError("MPI tasks are not in sync (lists have different lengths)")
    # The lengths agree on all tasks here, so an empty list is empty
    # everywhere and every task returns together.
    if not lst:
        return

    # Gather a type flag rather than the element itself: the branch below must
    # depend on whether the element is a SeedSequence, not on its truthiness.
    is_sseq = comm.allgather(isinstance(lst[0], np.random.SeedSequence))
    if is_sseq[0]:
        if not all(is_sseq):
            raise RuntimeError("First element in list is np.random.SeedSequence. The others (partly) not.")
        for oo in lst:
            check_MPI_equality(oo, comm)
        return

    # Dispatch every element through the public check so that nested lists and
    # seed sequences at later positions are compared correctly, and report
    # which index failed.
    for ii, elem in enumerate(lst):
        try:
            check_MPI_equality(elem, comm)
        except RuntimeError as err:
            raise RuntimeError(f"MPI tasks are not in sync (list element #{ii} does not match)") from err
def _MPI_unique(obj, comm):
    """Return True if `obj` compares equal on all tasks of `comm`.

    The objects are collected with `comm.allgather` and each one is compared
    to the first with `==`, so only `__eq__` is required; unhashable objects
    such as dicts are supported.

    Parameters
    ----------
    obj :
        Object to compare across the tasks.
    comm : MPI communicator
        Communicator providing `allgather`.

    Returns
    -------
    bool
        True if all gathered objects are equal to the first one, False
        otherwise (including when nothing was gathered).
    """
    gathered = comm.allgather(obj)
    if not gathered:
        return False
    reference = gathered[0]
    return all(other == reference for other in gathered[1:])
def _check_MPI_equality_sseq(sseq, comm):
    """Check that a np.random.SeedSequence is the same on all MPI tasks.

    A child sequence is spawned from `sseq`, one normally distributed number
    is drawn with it, and the numbers gathered from all tasks are compared.

    Parameters
    ----------
    sseq : np.random.SeedSequence
        Seed sequence to compare across the tasks.
        NOTE(review): spawning a child presumably advances the spawn counter
        of `sseq` -- confirm against `random.spawn_sseq`.
    comm : MPI communicator
        Communicator providing `allgather`.

    Raises
    ------
    TypeError
        If `sseq` is not a np.random.SeedSequence.
    RuntimeError
        If the drawn numbers are not identical on all tasks.
    """
    from .random import Context, current_rng, spawn_sseq

    if not isinstance(sseq, np.random.SeedSequence):
        raise TypeError

    child = spawn_sseq(1, parent=sseq)[0]
    with Context(child):
        draw = current_rng().normal(10., 1.2, (1,))[0]

    # Neighbouring entries all match exactly when every entry is the same.
    gathered = comm.allgather(draw)
    if gathered[:-1] != gathered[1:]:
        raise RuntimeError("SeedSequences are not equal")
def check_MPI_synced_random_state(comm):
    """Verify that the random state agrees on every task of a communicator.

    The state returned by `random.getState` is compared across the tasks
    with `check_MPI_equality`.

    Parameters
    ----------
    comm : MPI communicator or None
        If comm is None, no check will be performed.

    Raises
    ------
    RuntimeError
        If the random state differs between tasks.
    """
    from .random import getState

    if comm is not None:
        check_MPI_equality(getState(), comm)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from mpi4py import MPI
import nifty8 as ift
from ..common import setup_function, teardown_function
# World communicator shared by all tests in this module.
comm = MPI.COMM_WORLD
ntask, rank = comm.Get_size(), comm.Get_rank()
# True only on the task with rank 0.
master = rank == 0
# True when more than one MPI task takes part in the run.
mpi = ntask > 1
# Short aliases for the pytest markers used below.
pms = pytest.mark.skipif
pmp = pytest.mark.parametrize
@pms(not mpi, reason="requires at least two mpi tasks")
def test_MPI_equality():
    """Exercise `check_MPI_equality` with synced and unsynced objects.

    Rank-dependent objects must raise a RuntimeError; objects that are
    constructed identically on every task must pass.
    """
    # A scalar that differs on every task.
    obj = rank
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)

    # A list whose entries differ on every task.
    obj = [ii + rank for ii in range(10, 12)]
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)

    # Objects built the same way on all tasks.
    sseqs = ift.random.spawn_sseq(2)
    for obj in [12., None, (29, 30), [1, 2, 3], sseqs[0], sseqs]:
        ift.utilities.check_MPI_equality(obj, comm)

    # Seed sequences spawned from different parents on different tasks. The
    # index wraps around so that runs with more tasks than parents do not
    # raise IndexError on the higher ranks; ranks 0 and 1 still use different
    # parents, so the tasks remain out of sync.
    obj = ift.random.spawn_sseq(2, parent=sseqs[rank % len(sseqs)])
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)
@pms(not mpi, reason="requires at least two mpi tasks")
def test_MPI_synced_random_state():
    """Exercise `check_MPI_synced_random_state` before and after a desync.

    The first check is expected to pass. Then only the master task pushes a
    new seed sequence, after which the check must raise a RuntimeError.
    """
    check_synced = ift.utilities.check_MPI_synced_random_state

    check_synced(comm)

    # Change the random state on the master task only.
    if master:
        ift.random.push_sseq_from_seed(123)

    with pytest.raises(RuntimeError):
        check_synced(comm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment