Commit bee6f818 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add MPI sync tests

parent b7246142
Pipeline #106270 passed with stages
in 20 minutes and 52 seconds
......@@ -54,8 +54,38 @@ def test_MPI_equality():
@pms(not mpi, reason="requires at least two mpi tasks")
def test_MPI_synced_random_state():
ift.utilities.check_MPI_synced_random_state(comm)
with ift.random.Context(123 if master else 111):
with pytest.raises(RuntimeError):
ift.utilities.check_MPI_synced_random_state(comm)
if master:
ift.random.push_sseq_from_seed(123)
with pytest.raises(RuntimeError):
ift.utilities.check_MPI_synced_random_state(comm)
@pms(not mpi, reason="requires at least two mpi tasks")
@pmp("kl", [lambda a, b, c, d: ift.MetricGaussianKL(a, b, c, d, comm=comm),
lambda a, b, c, d: ift.GeoMetricKL(a, b, c,
ift.NewtonCG(ift.AbsDeltaEnergyController(0.1, iteration_limit=2)),
d, comm=comm)
])
@pmp("mirror", [False, True])
@pmp("n_samples", [2, 3])
def test_MPI_synced_random_state_kl_energies(kl, mirror, n_samples):
ic = ift.AbsDeltaEnergyController(0.1, iteration_limit=2)
lh = ift.GaussianEnergy(ift.full(ift.UnstructuredDomain(2), 2.)).ducktape("a")
ham = ift.StandardHamiltonian(lh, ic)
ift.utilities.check_MPI_synced_random_state(comm)
with ift.random.Context(123 if master else 111):
mean = ift.from_random(ham.domain)
with pytest.raises(RuntimeError):
kl(mean, ham, n_samples, mirror)
@pms(not mpi, reason="requires at least two mpi tasks")
@pmp("sync", [False, True])
def test_random_field_generation(sync):
with ift.random.Context(123 if master and not sync else 111):
dom = ift.UnstructuredDomain(5)
fld = ift.from_random(dom)
if sync:
ift.utilities.check_MPI_equality(fld, comm)
else:
with pytest.raises(RuntimeError):
ift.utilities.check_MPI_equality(fld, comm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment