From 5884a0ebd80a540b7ee27824f601ee63f4483c8a Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Thu, 19 Apr 2018 12:24:16 +0200
Subject: [PATCH] add sum_up flag to from_global_data()

---
 nifty4/data_objects/distributed_do.py |  6 ++++--
 nifty4/data_objects/numpy_do.py       |  2 +-
 nifty4/field.py                       | 11 +++++++----
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py
index 079414187..9085762cb 100644
--- a/nifty4/data_objects/distributed_do.py
+++ b/nifty4/data_objects/distributed_do.py
@@ -393,7 +393,9 @@ def from_local_data(shape, arr, distaxis=0):
     return data_object(shape, arr, distaxis)
 
 
-def from_global_data(arr, distaxis=0):
+def from_global_data(arr, sum_up=False, distaxis=0):
+    if sum_up:
+        arr = np_allreduce_sum(arr)
     if distaxis == -1:
         return data_object(arr.shape, arr, distaxis)
     lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
@@ -427,7 +429,7 @@ def redistribute(arr, dist=None, nodist=None):
                 break
 
     if arr._distaxis == -1:  # all data available, just pick the proper subset
-        return from_global_data(arr._data, dist)
+        return from_global_data(arr._data, distaxis=dist)
     if dist == -1:  # gather all data on all tasks
         tmp = np.moveaxis(arr._data, arr._distaxis, 0)
         slabsize = np.prod(tmp.shape[1:])*tmp.itemsize
diff --git a/nifty4/data_objects/numpy_do.py b/nifty4/data_objects/numpy_do.py
index e4324f49e..9e2c7354f 100644
--- a/nifty4/data_objects/numpy_do.py
+++ b/nifty4/data_objects/numpy_do.py
@@ -80,7 +80,7 @@ def from_local_data(shape, arr, distaxis=-1):
     return arr
 
 
-def from_global_data(arr, distaxis=-1):
+def from_global_data(arr, sum_up=False, distaxis=-1):
     return arr
 
 
diff --git a/nifty4/field.py b/nifty4/field.py
index 97d80934d..2f6d30f47 100644
--- a/nifty4/field.py
+++ b/nifty4/field.py
@@ -155,7 +155,7 @@ class Field(object):
         return Field.empty(field._domain, dtype)
 
     @staticmethod
-    def from_global_data(domain, arr):
+    def from_global_data(domain, arr, sum_up=False):
         """Returns a Field constructed from `domain` and `arr`.
 
         Parameters
@@ -165,10 +165,13 @@ class Field(object):
         arr : numpy.ndarray
             The data content to be used for the new Field.
             Its shape must match the shape of `domain`.
-            If MPI is active, the contents of `arr` must be the same on all
-            MPI tasks.
+        sum_up : bool, optional
+            If True, the contents of `arr` are summed up over all MPI tasks
+            (if any), and the sum is used as data content.
+            If False, the contens of `arr` are used directly, and must be
+            identical on all MPI tasks.
         """
-        return Field(domain, dobj.from_global_data(arr))
+        return Field(domain, dobj.from_global_data(arr, sum_up))
 
     @staticmethod
     def from_local_data(domain, arr):
-- 
GitLab