Commit 3b706044 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add utilities

parent c4ffab5b
Pipeline #96487 passed with stages
in 22 minutes and 17 seconds
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -21,11 +21,14 @@ from itertools import product
import numpy as np
from .random import getState
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
"value_reshaper", "lognormal_moments",
"assert_rngs_synchronized"]
def my_sum(iterable):
......@@ -349,7 +352,6 @@ def allreduce_sum(obj, comm):
who = np.zeros(nobj, dtype=np.int32)
rank = 0
else:
ntask = comm.Get_size()
rank = comm.Get_rank()
nobj_list = comm.allgather(len(vals))
all_hi = list(np.cumsum(nobj_list))
......@@ -379,6 +381,17 @@ def allreduce_sum(obj, comm):
return comm.bcast(vals[0], root=who[0])
def assert_rngs_synchronized(comm):
if comm is None:
return
else:
lst = comm.allgather(getState)
ref = lst[0]
for ll in lst:
if ll != ref:
raise RuntimeError("Random states out of sync on MPI tasks")
def value_reshaper(x, N):
"""Produce arrays of shape `(N,)`.
If `x` is a scalar or array of length one, fill the target array with it.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment