From 5884a0ebd80a540b7ee27824f601ee63f4483c8a Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 19 Apr 2018 12:24:16 +0200 Subject: [PATCH] add sum_up flag to from_global_data() --- nifty4/data_objects/distributed_do.py | 6 ++++-- nifty4/data_objects/numpy_do.py | 2 +- nifty4/field.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 079414187..9085762cb 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -393,7 +393,9 @@ def from_local_data(shape, arr, distaxis=0): return data_object(shape, arr, distaxis) -def from_global_data(arr, distaxis=0): +def from_global_data(arr, sum_up=False, distaxis=0): + if sum_up: + arr = np_allreduce_sum(arr) if distaxis == -1: return data_object(arr.shape, arr, distaxis) lo, hi = _shareRange(arr.shape[distaxis], ntask, rank) @@ -427,7 +429,7 @@ def redistribute(arr, dist=None, nodist=None): break if arr._distaxis == -1: # all data available, just pick the proper subset - return from_global_data(arr._data, dist) + return from_global_data(arr._data, distaxis=dist) if dist == -1: # gather all data on all tasks tmp = np.moveaxis(arr._data, arr._distaxis, 0) slabsize = np.prod(tmp.shape[1:])*tmp.itemsize diff --git a/nifty4/data_objects/numpy_do.py b/nifty4/data_objects/numpy_do.py index e4324f49e..9e2c7354f 100644 --- a/nifty4/data_objects/numpy_do.py +++ b/nifty4/data_objects/numpy_do.py @@ -80,7 +80,7 @@ def from_local_data(shape, arr, distaxis=-1): return arr -def from_global_data(arr, distaxis=-1): +def from_global_data(arr, sum_up=False, distaxis=-1): return arr diff --git a/nifty4/field.py b/nifty4/field.py index 97d80934d..2f6d30f47 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -155,7 +155,7 @@ class Field(object): return Field.empty(field._domain, dtype) @staticmethod - def from_global_data(domain, arr): + def from_global_data(domain, arr, sum_up=False): """Returns a Field constructed from `domain` and `arr`. Parameters @@ -165,10 +165,13 @@ class Field(object): arr : numpy.ndarray The data content to be used for the new Field. Its shape must match the shape of `domain`. - If MPI is active, the contents of `arr` must be the same on all - MPI tasks. + sum_up : bool, optional + If True, the contents of `arr` are summed up over all MPI tasks + (if any), and the sum is used as data content. + If False, the contens of `arr` are used directly, and must be + identical on all MPI tasks. """ - return Field(domain, dobj.from_global_data(arr)) + return Field(domain, dobj.from_global_data(arr, sum_up)) @staticmethod def from_local_data(domain, arr): -- GitLab