Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
16827157
Commit
16827157
authored
Jul 15, 2021
by
Martin Reinecke
Browse files
Merge branch 'mpi_tweaks' into 'NIFTy_8'
Mpi tweaks See merge request
!662
parents
f2ec76f0
a269bbb6
Pipeline
#105795
passed with stages
in 34 minutes and 56 seconds
Changes
4
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Dockerfile
View file @
16827157
...
...
@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies
&& pip3 install ducc0 finufft jupyter jax jaxlib sphinx pydata-sphinx-theme \
&&
DUCC0_OPTIMIZATION=portable
pip3 install ducc0 finufft jupyter jax jaxlib sphinx pydata-sphinx-theme \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
...
...
src/minimization/kl_energies.py
View file @
16827157
...
...
@@ -211,6 +211,7 @@ class _MetricGaussianSampler:
def
draw_samples
(
self
,
comm
):
local_samples
=
[]
utilities
.
check_MPI_synced_random_state
(
comm
)
sseq
=
random
.
spawn_sseq
(
self
.
_n
)
for
i
in
range
(
*
_get_lo_hi
(
comm
,
self
.
_n
)):
with
random
.
Context
(
sseq
[
i
]):
...
...
@@ -315,6 +316,8 @@ class _GeoMetricSampler:
def
draw_samples
(
self
,
comm
):
local_samples
=
[]
prev
=
None
utilities
.
check_MPI_synced_random_state
(
comm
)
utilities
.
check_MPI_equality
(
self
.
_sseq
,
comm
)
for
i
in
range
(
*
_get_lo_hi
(
comm
,
self
.
n_eff_samples
)):
with
random
.
Context
(
self
.
_sseq
[
i
]):
neg
=
self
.
_neg
[
i
]
...
...
@@ -407,7 +410,7 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
_
,
ham_sampling
=
_reduce_by_keys
(
mean
,
hamiltonian
,
point_estimates
)
sampler
=
_MetricGaussianSampler
(
mean
,
ham_sampling
,
n_samples
,
mirror_samples
)
mirror_samples
,
napprox
)
local_samples
=
sampler
.
draw_samples
(
comm
)
mean
,
hamiltonian
=
_reduce_by_keys
(
mean
,
hamiltonian
,
constants
)
...
...
@@ -416,7 +419,7 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
def
GeoMetricKL
(
mean
,
hamiltonian
,
n_samples
,
minimizer_samp
,
mirror_samples
,
start_from_lin
=
True
,
constants
=
[],
point_estimates
=
[],
start_from_lin
=
True
,
constants
=
[],
point_estimates
=
[],
napprox
=
0
,
comm
=
None
,
nanisinf
=
True
):
"""Provides the sampled Kullback-Leibler used in geometric Variational
Inference (geoVI).
...
...
@@ -487,10 +490,10 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
As in MGVI, mirroring samples can help to stabilize the latent mean as it
reduces sampling noise. But unlike MGVI a mirrored sample involves an
additional solve of the non-linear transformation. Therefore, when using
MPI, the mirrored samples also get distributed if enough tasks are
available.
If there are more total samples than tasks, the mirrored
counterparts
try to reside on the same task as their non mirrored partners.
This ensures
that at least the starting position can be re-used.
MPI, the mirrored samples also get distributed if enough tasks are
available.
If there are more total samples than tasks, the mirrored
counterparts
try to reside on the same task as their non mirrored partners.
This ensures
that at least the starting position can be re-used.
See also
--------
...
...
@@ -517,7 +520,8 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
_
,
ham_sampling
=
_reduce_by_keys
(
mean
,
hamiltonian
,
point_estimates
)
sampler
=
_GeoMetricSampler
(
mean
,
ham_sampling
,
minimizer_samp
,
start_from_lin
,
n_samples
,
mirror_samples
)
start_from_lin
,
n_samples
,
mirror_samples
,
napprox
)
local_samples
=
sampler
.
draw_samples
(
comm
)
mean
,
hamiltonian
=
_reduce_by_keys
(
mean
,
hamiltonian
,
constants
)
return
_SampledKLEnergy
(
mean
,
hamiltonian
,
sampler
.
n_eff_samples
,
False
,
...
...
src/utilities.py
View file @
16827157
...
...
@@ -25,7 +25,8 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo"
,
"NiftyMeta"
,
"my_sum"
,
"my_lincomb_simple"
,
"my_lincomb"
,
"indent"
,
"my_product"
,
"frozendict"
,
"special_add_at"
,
"iscomplextype"
,
"value_reshaper"
,
"lognormal_moments"
,
"check_domain_equality"
]
"value_reshaper"
,
"lognormal_moments"
,
"check_domain_equality"
,
"check_MPI_equality"
,
"check_MPI_synced_random_state"
]
def
my_sum
(
iterable
):
...
...
@@ -428,3 +429,80 @@ def check_domain_equality(domain0, domain1):
f
"ift.MultiDomain nor of ift.DomainTuple.
\n
{
dom
}
"
)
if
domain0
!=
domain1
:
raise
ValueError
(
f
"Domain mismatch:
\n
{
domain0
}
\n
{
domain1
}
"
)
def
check_MPI_equality
(
obj
,
comm
):
"""Check that object is the same on all MPI tasks associated to a given
communicator.
Raises a RuntimeError if it differs.
Parameters
----------
obj :
Any Python object that implements __eq__.
comm : MPI communicator or None
If comm is None, no check will be performed
"""
# Special cases
if
comm
is
None
:
return
elif
isinstance
(
obj
,
list
):
_check_MPI_equality_lists
(
obj
,
comm
)
elif
isinstance
(
obj
,
np
.
random
.
SeedSequence
):
_check_MPI_equality_sseq
(
obj
,
comm
)
# /Special cases
else
:
if
not
_MPI_unique
(
obj
,
comm
):
raise
RuntimeError
(
"MPI tasks are not in sync"
)
def
_check_MPI_equality_lists
(
lst
,
comm
):
if
not
isinstance
(
lst
,
list
):
raise
TypeError
if
not
_MPI_unique
(
len
(
lst
),
comm
):
raise
RuntimeError
(
"MPI tasks are not in sync (lists have different lengths)"
)
is_sseq
=
comm
.
allgather
(
lst
[
0
])
if
is_sseq
[
0
]:
if
not
all
(
is_sseq
):
raise
RuntimeError
(
"First element in list is np.random.SeedSequence. The others (partly) not."
)
for
oo
in
lst
:
check_MPI_equality
(
oo
,
comm
)
return
for
ii
in
range
(
len
(
lst
)):
if
not
_MPI_unique
(
lst
[
ii
],
comm
):
raise
RuntimeError
(
f
"MPI tasks are not in sync (list element #
{
ii
}
does not match)"
)
def
_MPI_unique
(
obj
,
comm
):
return
len
(
set
(
comm
.
allgather
(
obj
)))
==
1
def
_check_MPI_equality_sseq
(
sseq
,
comm
):
from
.random
import
Context
,
spawn_sseq
,
current_rng
if
not
isinstance
(
sseq
,
np
.
random
.
SeedSequence
):
raise
TypeError
with
Context
(
spawn_sseq
(
1
,
parent
=
sseq
)[
0
]):
random_number
=
current_rng
().
normal
(
10.
,
1.2
,
(
1
,))[
0
]
gath
=
comm
.
allgather
(
random_number
)
if
gath
[
1
:]
!=
gath
[:
-
1
]:
raise
RuntimeError
(
"SeedSequences are not equal"
)
def
check_MPI_synced_random_state
(
comm
):
"""Check that random state is the same on all MPI tasks associated to a
given communicator.
Raises a RuntimeError if it differs.
Parameters
----------
comm : MPI communicator or None
If comm is None, no check will be performed
"""
from
.random
import
getState
if
comm
is
None
:
return
check_MPI_equality
(
getState
(),
comm
)
test/test_mpi/test_sync.py
0 → 100644
View file @
16827157
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
pytest
from
mpi4py
import
MPI
import
nifty8
as
ift
from
..common
import
setup_function
,
teardown_function
comm
=
MPI
.
COMM_WORLD
ntask
=
comm
.
Get_size
()
rank
=
comm
.
Get_rank
()
master
=
(
rank
==
0
)
mpi
=
ntask
>
1
pmp
=
pytest
.
mark
.
parametrize
pms
=
pytest
.
mark
.
skipif
@
pms
(
not
mpi
,
reason
=
"requires at least two mpi tasks"
)
def
test_MPI_equality
():
obj
=
rank
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
obj
=
[
ii
+
rank
for
ii
in
range
(
10
,
12
)]
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
sseqs
=
ift
.
random
.
spawn_sseq
(
2
)
for
obj
in
[
12.
,
None
,
(
29
,
30
),
[
1
,
2
,
3
],
sseqs
[
0
],
sseqs
]:
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
obj
=
ift
.
random
.
spawn_sseq
(
2
,
parent
=
sseqs
[
comm
.
rank
])
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_equality
(
obj
,
comm
)
@
pms
(
not
mpi
,
reason
=
"requires at least two mpi tasks"
)
def
test_MPI_synced_random_state
():
ift
.
utilities
.
check_MPI_synced_random_state
(
comm
)
if
master
:
ift
.
random
.
push_sseq_from_seed
(
123
)
with
pytest
.
raises
(
RuntimeError
):
ift
.
utilities
.
check_MPI_synced_random_state
(
comm
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment