Commit 6a59fbc8 authored by Martin Reinecke

small enhancements for data objects

parent e7b0e33b
Pipeline #23798 passed in 4 minutes and 38 seconds
@@ -44,7 +44,7 @@ def _shareRange(nwork, nshares, myshare):
     return lo, hi
 
 
-def local_shape(shape, distaxis):
+def local_shape(shape, distaxis=0):
     if len(shape) == 0 or distaxis == -1:
         return shape
     shape2 = list(shape)
@@ -344,6 +344,14 @@ def local_data(arr):
     return arr._data
 
 
+def ibegin_from_shape(glob_shape, distaxis=0):
+    res = [0] * len(glob_shape)
+    if distaxis < 0:
+        return res
+    res[distaxis] = _shareRange(glob_shape[distaxis], ntask, rank)[0]
+    return tuple(res)
+
+
 def ibegin(arr):
     res = [0] * arr._data.ndim
     res[arr._distaxis] = _shareRange(arr._shape[arr._distaxis], ntask, rank)[0]
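This hunk appears to extend the distributed (MPI) backend: the new `ibegin_from_shape` reports, for a given global shape, the global index at which the calling task's slab of the data starts along the distribution axis, reusing the same `_shareRange` split as the existing `ibegin`. Below is a minimal, self-contained sketch of that behaviour; the chunking inside `_shareRange` and the explicit `ntask`/`rank` parameters are assumptions for illustration, not code taken from this repository.

def _shareRange(nwork, nshares, myshare):
    # Assumed chunking: near-equal contiguous slices, earlier shares one element larger.
    nbase, extra = divmod(nwork, nshares)
    lo = myshare*nbase + min(myshare, extra)
    hi = lo + nbase + int(myshare < extra)
    return lo, hi

def ibegin_from_shape(glob_shape, distaxis=0, ntask=1, rank=0):
    # Global index of the first locally stored element, per axis.
    res = [0]*len(glob_shape)
    if distaxis < 0:          # undistributed data starts at the origin
        return tuple(res)
    res[distaxis] = _shareRange(glob_shape[distaxis], ntask, rank)[0]
    return tuple(res)

# With 10 rows split over 4 tasks, rank 1 holds rows 3..5, so its
# local block starts at global index (3, 0).
print(ibegin_from_shape((10, 5), distaxis=0, ntask=4, rank=1))  # -> (3, 0)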
@@ -47,6 +47,10 @@ def local_data(arr):
     return arr
 
 
+def ibegin_from_shape(glob_shape, distaxis=-1):
+    return (0,)*len(glob_shape)
+
+
 def ibegin(arr):
     return (0,)*arr.ndim
@@ -81,5 +85,5 @@ def default_distaxis():
     return -1
 
 
-def local_shape(glob_shape, distaxis):
+def local_shape(glob_shape, distaxis=-1):
     return glob_shape
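The two preceding hunks appear to be the matching changes in the serial (single-task) backend, where the whole array lives on one task: the local block always starts at the global origin and the local shape equals the global shape. A short illustrative sketch of that contract (the function bodies mirror the hunks above; the example values are assumptions):

def ibegin_from_shape(glob_shape, distaxis=-1):
    # Serial backend: the full array is local, so every axis starts at 0.
    return (0,)*len(glob_shape)

def local_shape(glob_shape, distaxis=-1):
    # Serial backend: local and global shapes coincide.
    return glob_shape

glob_shape = (10, 5)
assert ibegin_from_shape(glob_shape) == (0, 0)
assert local_shape(glob_shape) == (10, 5)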
@@ -31,6 +31,6 @@ except ImportError:
 __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
            "empty", "zeros", "ones", "empty_like", "vdot", "abs", "exp",
            "log", "tanh", "sqrt", "from_object", "from_random",
-           "local_data", "ibegin", "np_allreduce_sum", "distaxis",
-           "from_local_data", "from_global_data", "to_global_data",
+           "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
+           "distaxis", "from_local_data", "from_global_data", "to_global_data",
            "redistribute", "default_distaxis", "mprint"]