Commit 0d382cd1 authored by Philipp Arras
Browse files

Merge branch 'backports' into 'NIFTy_7'

Backports nifty8 -> nifty7

See merge request !663
parents c9687c8a 9ffd5a56
Pipeline #105882 passed with stages
in 30 minutes and 42 seconds
......@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies
&& pip3 install ducc0 finufft jupyter sphinx pydata-sphinx-theme \
&& DUCC0_OPTIMIZATION=portable pip3 install ducc0 finufft jupyter sphinx pydata-sphinx-theme \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
......
......@@ -211,6 +211,7 @@ class _MetricGaussianSampler:
def draw_samples(self, comm):
local_samples = []
utilities.check_MPI_synced_random_state(comm)
sseq = random.spawn_sseq(self._n)
for i in range(*_get_lo_hi(comm, self._n)):
with random.Context(sseq[i]):
......@@ -315,6 +316,8 @@ class _GeoMetricSampler:
def draw_samples(self, comm):
local_samples = []
prev = None
utilities.check_MPI_synced_random_state(comm)
utilities.check_MPI_equality(self._sseq, comm)
for i in range(*_get_lo_hi(comm, self.n_eff_samples)):
with random.Context(self._sseq[i]):
neg = self._neg[i]
......@@ -407,7 +410,7 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
_, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
sampler = _MetricGaussianSampler(mean, ham_sampling, n_samples,
mirror_samples)
mirror_samples, napprox)
local_samples = sampler.draw_samples(comm)
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
......@@ -416,7 +419,7 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
start_from_lin = True, constants=[], point_estimates=[],
start_from_lin=True, constants=[], point_estimates=[],
napprox=0, comm=None, nanisinf=True):
"""Provides the sampled Kullback-Leibler used in geometric Variational
Inference (geoVI).
......@@ -487,10 +490,10 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
As in MGVI, mirroring samples can help to stabilize the latent mean as it
reduces sampling noise. But unlike MGVI a mirrored sample involves an
additional solve of the non-linear transformation. Therefore, when using
MPI, the mirrored samples also get distributed if enough tasks are available.
If there are more total samples than tasks, the mirrored counterparts
try to reside on the same task as their non mirrored partners. This ensures
that at least the starting position can be re-used.
MPI, the mirrored samples also get distributed if enough tasks are
available. If there are more total samples than tasks, the mirrored
counterparts try to reside on the same task as their non mirrored partners.
This ensures that at least the starting position can be re-used.
See also
--------
......@@ -517,7 +520,8 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
_, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
sampler = _GeoMetricSampler(mean, ham_sampling, minimizer_samp,
start_from_lin, n_samples, mirror_samples)
start_from_lin, n_samples, mirror_samples,
napprox)
local_samples = sampler.draw_samples(comm)
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
return _SampledKLEnergy(mean, hamiltonian, sampler.n_eff_samples, False,
......
......@@ -25,7 +25,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
"value_reshaper", "lognormal_moments",
"check_MPI_equality", "check_MPI_synced_random_state"]
def my_sum(iterable):
......@@ -412,3 +413,80 @@ def myassert(val):
`__debug__` is False."""
if not val:
raise AssertionError
def check_MPI_equality(obj, comm):
    """Verify that an object is identical on every task of an MPI
    communicator.

    Lists and `np.random.SeedSequence` instances are dispatched to dedicated
    comparison routines; every other object is handed to `_MPI_unique`.

    Parameters
    ----------
    obj :
        Any Python object that implements __eq__.
    comm : MPI communicator or None
        If comm is None, no check will be performed

    Raises
    ------
    RuntimeError
        If the object differs between the tasks.
    """
    # Nothing to compare against without a communicator
    if comm is None:
        return

    # Types that need special treatment
    if isinstance(obj, list):
        _check_MPI_equality_lists(obj, comm)
        return
    if isinstance(obj, np.random.SeedSequence):
        _check_MPI_equality_sseq(obj, comm)
        return

    # Generic case
    if _MPI_unique(obj, comm):
        return
    raise RuntimeError("MPI tasks are not in sync")
def _check_MPI_equality_lists(lst, comm):
    """Check that a list is the same on all tasks of an MPI communicator.

    The list lengths are compared first. If the first element is a
    `np.random.SeedSequence` on the first task, it must be one on every task
    and all elements are compared via `check_MPI_equality`. Otherwise the
    elements are compared one by one with `_MPI_unique`.

    Parameters
    ----------
    lst : list
        The list to compare across tasks.
    comm : MPI communicator
        Communicator whose tasks are compared.

    Raises
    ------
    TypeError
        If `lst` is not a list.
    RuntimeError
        If the lists differ in length or content between tasks.
    """
    if not isinstance(lst, list):
        raise TypeError

    if not _MPI_unique(len(lst), comm):
        raise RuntimeError("MPI tasks are not in sync (lists have different lengths)")

    # All tasks agree on the length at this point, so an empty list is
    # trivially in sync and has no first element to inspect.
    if not lst:
        return

    # Gather a boolean flag per task, not the element itself: the truthiness
    # of an arbitrary first element says nothing about its type.
    is_sseq = comm.allgather(isinstance(lst[0], np.random.SeedSequence))
    if is_sseq[0]:
        if not all(is_sseq):
            raise RuntimeError("First element in list is np.random.SeedSequence. The others (partly) not.")
        for oo in lst:
            check_MPI_equality(oo, comm)
        return

    for ii, elem in enumerate(lst):
        if not _MPI_unique(elem, comm):
            raise RuntimeError(f"MPI tasks are not in sync (list element #{ii} does not match)")
def _MPI_unique(obj, comm):
return len(set(comm.allgather(obj))) == 1
def _check_MPI_equality_sseq(sseq, comm):
    """Check that a `np.random.SeedSequence` is the same on all MPI tasks.

    One normally distributed number is drawn inside a random `Context` built
    from a child spawned off `sseq`; the numbers of all tasks are then
    gathered and compared.

    Parameters
    ----------
    sseq : np.random.SeedSequence
        Seed sequence to compare across tasks.
    comm : MPI communicator
        Communicator providing `allgather`.

    Raises
    ------
    TypeError
        If `sseq` is not a `np.random.SeedSequence`.
    RuntimeError
        If the drawn numbers differ between tasks.
    """
    from .random import Context, spawn_sseq, current_rng

    if not isinstance(sseq, np.random.SeedSequence):
        raise TypeError

    # NOTE(review): presumably `Context` makes the spawned child the active
    # seed for `current_rng()` -- confirm against the random module.
    child = spawn_sseq(1, parent=sseq)[0]
    with Context(child):
        draw = current_rng().normal(10., 1.2, (1,))[0]

    gathered = comm.allgather(draw)
    # Comparing the list with itself shifted by one tests all neighbouring
    # pairs at once.
    if gathered[1:] != gathered[:-1]:
        raise RuntimeError("SeedSequences are not equal")
def check_MPI_synced_random_state(comm):
    """Verify that all tasks of an MPI communicator share the same random
    state.

    The state returned by `getState()` is compared across tasks with
    `check_MPI_equality`.

    Parameters
    ----------
    comm : MPI communicator or None
        If comm is None, no check will be performed

    Raises
    ------
    RuntimeError
        If the random state differs between the tasks.
    """
    from .random import getState

    if comm is not None:
        state = getState()
        check_MPI_equality(state, comm)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from mpi4py import MPI
import nifty7 as ift
from ..common import setup_function, teardown_function
# MPI environment of the current test process
comm = MPI.COMM_WORLD
ntask = comm.Get_size()
rank = comm.Get_rank()
# True only on the task with rank 0
master = (rank == 0)
# True if the tests run with more than one MPI task
mpi = ntask > 1
# Shorthands for the pytest markers used as decorators below
pmp = pytest.mark.parametrize
pms = pytest.mark.skipif
@pms(not mpi, reason="requires at least two mpi tasks")
def test_MPI_equality():
    """Exercise `check_MPI_equality` with synced and unsynced objects.

    Rank-dependent objects must raise a RuntimeError, rank-independent ones
    must pass silently.
    """
    # A rank-dependent scalar differs between tasks
    obj = rank
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)

    # A list with rank-dependent entries differs between tasks
    obj = [ii + rank for ii in range(10, 12)]
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)

    # Objects built without reference to the rank must pass.
    # NOTE(review): assumes the random state is synced across tasks at test
    # start so that `sseqs` agrees everywhere -- confirm in `..common`.
    sseqs = ift.random.spawn_sseq(2)
    for obj in [12., None, (29, 30), [1, 2, 3], sseqs[0], sseqs]:
        ift.utilities.check_MPI_equality(obj, comm)

    # Spawn from a rank-dependent parent so the children differ between tasks.
    # Only len(sseqs) parents exist, so wrap the rank to avoid an IndexError
    # when running with more tasks than parents; ranks 0 and 1 still use
    # different parents, which is enough to trigger the error.
    obj = ift.random.spawn_sseq(2, parent=sseqs[rank % len(sseqs)])
    with pytest.raises(RuntimeError):
        ift.utilities.check_MPI_equality(obj, comm)
@pms(not mpi, reason="requires at least two mpi tasks")
def test_MPI_synced_random_state():
    """Check that `check_MPI_synced_random_state` detects a random state that
    was changed on the master task only."""
    check_synced = ift.utilities.check_MPI_synced_random_state

    # Expected to pass before any task-specific change is made
    check_synced(comm)

    # Change the random state on the master task only
    if master:
        ift.random.push_sseq_from_seed(123)

    with pytest.raises(RuntimeError):
        check_synced(comm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment