Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9ffd5a56
Commit
9ffd5a56
authored
Jul 15, 2021
by
Philipp Arras
Browse files
KL energies: check that random state is synced across MPI tasks
parent
c2821f81
Pipeline
#105773
passed with stages
in 15 minutes and 15 seconds
Changes
3
Pipelines
1
Show whitespace changes
Inline
Side-by-side
src/minimization/kl_energies.py
View file @
9ffd5a56
...
...
@@ -211,6 +211,7 @@ class _MetricGaussianSampler:
def
draw_samples
(
self
,
comm
):
local_samples
=
[]
utilities
.
check_MPI_synced_random_state
(
comm
)
sseq
=
random
.
spawn_sseq
(
self
.
_n
)
for
i
in
range
(
*
_get_lo_hi
(
comm
,
self
.
_n
)):
with
random
.
Context
(
sseq
[
i
]):
...
...
@@ -315,6 +316,8 @@ class _GeoMetricSampler:
def
draw_samples
(
self
,
comm
):
local_samples
=
[]
prev
=
None
utilities
.
check_MPI_synced_random_state
(
comm
)
utilities
.
check_MPI_equality
(
self
.
_sseq
,
comm
)
for
i
in
range
(
*
_get_lo_hi
(
comm
,
self
.
n_eff_samples
)):
with
random
.
Context
(
self
.
_sseq
[
i
]):
neg
=
self
.
_neg
[
i
]
...
...
src/utilities.py
View file @
9ffd5a56
...
...
@@ -25,7 +25,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo"
,
"NiftyMeta"
,
"my_sum"
,
"my_lincomb_simple"
,
"my_lincomb"
,
"indent"
,
"my_product"
,
"frozendict"
,
"special_add_at"
,
"iscomplextype"
,
"value_reshaper"
,
"lognormal_moments"
]
"value_reshaper"
,
"lognormal_moments"
,
"check_MPI_equality"
,
"check_MPI_synced_random_state"
]
def
my_sum
(
iterable
):
...
...
@@ -412,3 +413,80 @@ def myassert(val):
`__debug__` is False."""
if
not
val
:
raise
AssertionError
def
check_MPI_equality
(
obj
,
comm
):
"""Check that object is the same on all MPI tasks associated to a given
communicator.
Raises a RuntimeError if it differs.
Parameters
----------
obj :
Any Python object that implements __eq__.
comm : MPI communicator or None
If comm is None, no check will be performed
"""
# Special cases
if
comm
is
None
:
return
elif
isinstance
(
obj
,
list
):
_check_MPI_equality_lists
(
obj
,
comm
)
elif
isinstance
(
obj
,
np
.
random
.
SeedSequence
):
_check_MPI_equality_sseq
(
obj
,
comm
)
# /Special cases
else
:
if
not
_MPI_unique
(
obj
,
comm
):
raise
RuntimeError
(
"MPI tasks are not in sync"
)
def
_check_MPI_equality_lists
(
lst
,
comm
):
if
not
isinstance
(
lst
,
list
):
raise
TypeError
if
not
_MPI_unique
(
len
(
lst
),
comm
):
raise
RuntimeError
(
"MPI tasks are not in sync (lists have different lengths)"
)
is_sseq
=
comm
.
allgather
(
lst
[
0
])
if
is_sseq
[
0
]:
if
not
all
(
is_sseq
):
raise
RuntimeError
(
"First element in list is np.random.SeedSequence. The others (partly) not."
)
for
oo
in
lst
:
check_MPI_equality
(
oo
,
comm
)
return
for
ii
in
range
(
len
(
lst
)):
if
not
_MPI_unique
(
lst
[
ii
],
comm
):
raise
RuntimeError
(
f
"MPI tasks are not in sync (list element #
{
ii
}
does not match)"
)
def
_MPI_unique
(
obj
,
comm
):
return
len
(
set
(
comm
.
allgather
(
obj
)))
==
1
def
_check_MPI_equality_sseq
(
sseq
,
comm
):
from
.random
import
Context
,
spawn_sseq
,
current_rng
if
not
isinstance
(
sseq
,
np
.
random
.
SeedSequence
):
raise
TypeError
with
Context
(
spawn_sseq
(
1
,
parent
=
sseq
)[
0
]):
random_number
=
current_rng
().
normal
(
10.
,
1.2
,
(
1
,))[
0
]
gath
=
comm
.
allgather
(
random_number
)
if
gath
[
1
:]
!=
gath
[:
-
1
]:
raise
RuntimeError
(
"SeedSequences are not equal"
)
def
check_MPI_synced_random_state
(
comm
):
"""Check that random state is the same on all MPI tasks associated to a
given communicator.
Raises a RuntimeError if it differs.
Parameters
----------
comm : MPI communicator or None
If comm is None, no check will be performed
"""
from
.random
import
getState
if
comm
is
None
:
return
check_MPI_equality
(
getState
(),
comm
)
test/test_mpi/test_sync.py
0 → 100644
View file @
9ffd5a56
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
pytest
from
mpi4py
import
MPI
import
nifty7
as
ift
from
..common
import
setup_function
,
teardown_function
comm
=
MPI
.
COMM_WORLD
ntask
=
comm
.
Get_size
()
rank
=
comm
.
Get_rank
()
master
=
(
rank
==
0
)
mpi
=
ntask
>
1
pmp
=
pytest
.
mark
.
parametrize
pms
=
pytest
.
mark
.
skipif
@
pms
(
not
mpi
,
reason
=
"requires at least two mpi tasks"
)
def
test_MPI_equality
():
obj
=
rank
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
obj
=
[
ii
+
rank
for
ii
in
range
(
10
,
12
)]
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
sseqs
=
ift
.
random
.
spawn_sseq
(
2
)
for
obj
in
[
12.
,
None
,
(
29
,
30
),
[
1
,
2
,
3
],
sseqs
[
0
],
sseqs
]:
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
obj
=
ift
.
random
.
spawn_sseq
(
2
,
parent
=
sseqs
[
comm
.
rank
])
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
@
pms
(
not
mpi
,
reason
=
"requires at least two mpi tasks"
)
def
test_MPI_synced_random_state
():
ift
.
utilities
.
check_MPI_synced_random_state
(
comm
)
if
master
:
ift
.
random
.
push_sseq_from_seed
(
123
)
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_synced_random_state
(
comm
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment