Commit 5884a0eb authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add sum_up flag to from_global_data()

parent 93648c24
Pipeline #27628 passed with stages
in 20 minutes and 11 seconds
...@@ -393,7 +393,9 @@ def from_local_data(shape, arr, distaxis=0): ...@@ -393,7 +393,9 @@ def from_local_data(shape, arr, distaxis=0):
return data_object(shape, arr, distaxis) return data_object(shape, arr, distaxis)
def from_global_data(arr, distaxis=0): def from_global_data(arr, sum_up=False, distaxis=0):
if sum_up:
arr = np_allreduce_sum(arr)
if distaxis == -1: if distaxis == -1:
return data_object(arr.shape, arr, distaxis) return data_object(arr.shape, arr, distaxis)
lo, hi = _shareRange(arr.shape[distaxis], ntask, rank) lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
...@@ -427,7 +429,7 @@ def redistribute(arr, dist=None, nodist=None): ...@@ -427,7 +429,7 @@ def redistribute(arr, dist=None, nodist=None):
break break
if arr._distaxis == -1: # all data available, just pick the proper subset if arr._distaxis == -1: # all data available, just pick the proper subset
return from_global_data(arr._data, dist) return from_global_data(arr._data, distaxis=dist)
if dist == -1: # gather all data on all tasks if dist == -1: # gather all data on all tasks
tmp = np.moveaxis(arr._data, arr._distaxis, 0) tmp = np.moveaxis(arr._data, arr._distaxis, 0)
slabsize = np.prod(tmp.shape[1:])*tmp.itemsize slabsize = np.prod(tmp.shape[1:])*tmp.itemsize
......
...@@ -80,7 +80,7 @@ def from_local_data(shape, arr, distaxis=-1): ...@@ -80,7 +80,7 @@ def from_local_data(shape, arr, distaxis=-1):
return arr return arr
def from_global_data(arr, distaxis=-1): def from_global_data(arr, sum_up=False, distaxis=-1):
return arr return arr
......
...@@ -155,7 +155,7 @@ class Field(object): ...@@ -155,7 +155,7 @@ class Field(object):
return Field.empty(field._domain, dtype) return Field.empty(field._domain, dtype)
@staticmethod @staticmethod
def from_global_data(domain, arr): def from_global_data(domain, arr, sum_up=False):
"""Returns a Field constructed from `domain` and `arr`. """Returns a Field constructed from `domain` and `arr`.
Parameters Parameters
...@@ -165,10 +165,13 @@ class Field(object): ...@@ -165,10 +165,13 @@ class Field(object):
arr : numpy.ndarray arr : numpy.ndarray
The data content to be used for the new Field. The data content to be used for the new Field.
Its shape must match the shape of `domain`. Its shape must match the shape of `domain`.
If MPI is active, the contents of `arr` must be the same on all sum_up : bool, optional
MPI tasks. If True, the contents of `arr` are summed up over all MPI tasks
(if any), and the sum is used as data content.
If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks.
""" """
return Field(domain, dobj.from_global_data(arr)) return Field(domain, dobj.from_global_data(arr, sum_up))
@staticmethod @staticmethod
def from_local_data(domain, arr): def from_local_data(domain, arr):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment