diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 079414187ec190f78f4e5b4445f92fa25c458ab7..9085762cbfbd0804c28107459fc3ba73148b32d1 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -393,7 +393,9 @@ def from_local_data(shape, arr, distaxis=0): return data_object(shape, arr, distaxis) -def from_global_data(arr, distaxis=0): +def from_global_data(arr, sum_up=False, distaxis=0): + if sum_up: + arr = np_allreduce_sum(arr) if distaxis == -1: return data_object(arr.shape, arr, distaxis) lo, hi = _shareRange(arr.shape[distaxis], ntask, rank) @@ -427,7 +429,7 @@ def redistribute(arr, dist=None, nodist=None): break if arr._distaxis == -1: # all data available, just pick the proper subset - return from_global_data(arr._data, dist) + return from_global_data(arr._data, distaxis=dist) if dist == -1: # gather all data on all tasks tmp = np.moveaxis(arr._data, arr._distaxis, 0) slabsize = np.prod(tmp.shape[1:])*tmp.itemsize diff --git a/nifty4/data_objects/numpy_do.py b/nifty4/data_objects/numpy_do.py index e4324f49eb6795ae8b26fdf43fcc1ac4733efce9..9e2c7354fee94d52cfdbac9ec5c61d1bbb1bad5a 100644 --- a/nifty4/data_objects/numpy_do.py +++ b/nifty4/data_objects/numpy_do.py @@ -80,7 +80,7 @@ def from_local_data(shape, arr, distaxis=-1): return arr -def from_global_data(arr, distaxis=-1): +def from_global_data(arr, sum_up=False, distaxis=-1): return arr diff --git a/nifty4/field.py b/nifty4/field.py index 97d80934df4048d938d936382a39fcc4f5330d07..2f6d30f478ee730c07b3a7252845a20ecd7a32bb 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -155,7 +155,7 @@ class Field(object): return Field.empty(field._domain, dtype) @staticmethod - def from_global_data(domain, arr): + def from_global_data(domain, arr, sum_up=False): """Returns a Field constructed from `domain` and `arr`. Parameters @@ -165,10 +165,13 @@ class Field(object): arr : numpy.ndarray The data content to be used for the new Field. Its shape must match the shape of `domain`. - If MPI is active, the contents of `arr` must be the same on all - MPI tasks. + sum_up : bool, optional + If True, the contents of `arr` are summed up over all MPI tasks + (if any), and the sum is used as data content. + If False, the contens of `arr` are used directly, and must be + identical on all MPI tasks. """ - return Field(domain, dobj.from_global_data(arr)) + return Field(domain, dobj.from_global_data(arr, sum_up)) @staticmethod def from_local_data(domain, arr):