Commit f53b3afa authored by Martin Reinecke's avatar Martin Reinecke
Browse files


parent 153d1ce1
Pipeline #21248 passed with stage
in 4 minutes and 16 seconds
...@@ -10,6 +10,12 @@ rank = comm.Get_rank() ...@@ -10,6 +10,12 @@ rank = comm.Get_rank()
def shareSize(nwork, nshares, myshare):
    """Return the number of work items assigned to share `myshare`.

    `nwork` items are split as evenly as possible among `nshares`
    shares; the first `nwork % nshares` shares receive one extra item.
    """
    nbase = nwork//nshares
    # shares with index < nwork % nshares carry one additional item
    return nbase if myshare >= nwork % nshares else nbase+1
def shareRange(nwork, nshares, myshare):
    """Return the half-open index range [lo, hi) owned by share `myshare`.

    `nwork` items are distributed over `nshares` shares; the first
    `nwork % nshares` shares each hold one extra item, so the ranges of
    consecutive shares tile [0, nwork) without gaps.
    """
    nbase = nwork//nshares
    additional = nwork % nshares
    # every preceding share holds nbase items, plus one extra item for
    # each of the first `additional` shares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + (1 if myshare < additional else 0)
    return lo, hi
def get_locshape(shape, distaxis): def get_locshape(shape, distaxis):
if distaxis==-1: if distaxis==-1:
...@@ -37,10 +43,10 @@ class data_object(object): ...@@ -37,10 +43,10 @@ class data_object(object):
raise ValueError raise ValueError
# check whether the global shape is consistent # check whether the global shape is consistent
itmp=np.array(self._shape) itmp=np.array(self._shape)
otmp=np.empty((len(self._shape),ntask), otmp=np.empty((ntask,len(self._shape)),
comm.Allgather(itmp,otmp) comm.Allgather(itmp,otmp)
for i in range(ntask): for i in range(ntask):
if (otmp[i,:]!=self._shape).any(): if np.any(otmp[i,:]!=self._shape):
raise ValueError raise ValueError
# check shape of local data # check shape of local data
if self._distaxis<0: if self._distaxis<0:
...@@ -48,8 +54,8 @@ class data_object(object): ...@@ -48,8 +54,8 @@ class data_object(object):
raise ValueError raise ValueError
else: else:
itmp=np.array(self._shape) itmp=np.array(self._shape)
itmp[self._distaxis] = get_local_length(self._shape[self._distaxis],ntask,rank) itmp[self._distaxis] = shareSize(self._shape[self._distaxis],ntask,rank)
if self._data.shape!=itmp: if np.any(self._data.shape!=itmp):
raise ValueError raise ValueError
@property @property
...@@ -66,27 +72,30 @@ class data_object(object): ...@@ -66,27 +72,30 @@ class data_object(object):
@property @property
def real(self): def real(self):
return data_object(self._shape, self._data.real, self._dist_axis) return data_object(self._shape, self._data.real, self._distaxis)
@property @property
def imag(self): def imag(self):
return data_object(self._shape, self._data.imag, self._dist_axis) return data_object(self._shape, self._data.imag, self._distaxis)
def _contraction_helper(self, op, axis): def _contraction_helper(self, op, mpiop, axis):
if axis is not None: if axis is not None:
if len(axis)==len(self._data.shape): if len(axis)==len(self._data.shape):
axis = None axis = None
if axis is None: if axis is None:
res = getattr(self._data, op)() res = np.array(getattr(self._data, op)())
if (self._distaxis==-1):
return res
res2 = np.empty(1,dtype=res.dtype)
MPI.COMM_WORLD.Allreduce(res,res2,mpiop) MPI.COMM_WORLD.Allreduce(res,res2,mpiop)
return res2[0]
if self._distaxis in axis: if self._distaxis in axis:
pass# reduce globally, redistribute the result along axis 0(?) pass# reduce globally, redistribute the result along axis 0(?)
else: else:
pass# reduce locally # perform the contraction on the local data
data = getattr(self._data, op)(axis=axis)
# perform the contraction on the data #shp =
data = getattr(self._data, op)(axis=axis)
# check if the result is scalar or if a result_field must be constr. # check if the result is scalar or if a result_field must be constr.
if np.isscalar(data): if np.isscalar(data):
...@@ -98,14 +107,17 @@ class data_object(object): ...@@ -98,14 +107,17 @@ class data_object(object):
return self._contraction_helper("sum", MPI.SUM, axis) return self._contraction_helper("sum", MPI.SUM, axis)
def _binary_helper(self, other, op): def _binary_helper(self, other, op):
a = self._data a = self
if isinstance(other, data_object): if isinstance(other, data_object):
b = other._data b = other
if a._shape != b._shape: if a._shape != b._shape:
raise ValueError("shapes are incompatible.") raise ValueError("shapes are incompatible.")
if a._distaxis != b._distaxis: if a._distaxis != b._distaxis:
raise ValueError("distributions are incompatible.") raise ValueError("distributions are incompatible.")
a = a._data
b = b._data
else: else:
a = a._data
b = other b = other
tval = getattr(a, op)(b) tval = getattr(a, op)(b)
...@@ -184,8 +196,8 @@ class data_object(object): ...@@ -184,8 +196,8 @@ class data_object(object):
return self._data.any() return self._data.any()
def full(shape, fill_value, dtype=None, dist_axis=0): def full(shape, fill_value, dtype=None, distaxis=0):
return data_object(shape, np.full(shape, local_shape(shape, dist_axis), fill_value, dtype)) return data_object(shape, np.full(shape, local_shape(shape, distaxis), fill_value, dtype))
def empty(shape, dtype=np.float): def empty(shape, dtype=np.float):
...@@ -254,3 +266,97 @@ def to_ndarray(arr): ...@@ -254,3 +266,97 @@ def to_ndarray(arr):
def from_ndarray(arr):
    """Wrap a full, non-distributed numpy array into a data_object.

    A distaxis of -1 marks the object as not distributed (see the
    distaxis==-1 branch in from_global_data).
    """
    return data_object(arr.shape, arr, -1)
def local_data(arr):
    """Return the process-local numpy chunk stored inside *arr*."""
    chunk = arr._data
    return chunk
#def ibegin(arr):
# return (0,)*arr._data.ndim
#def np_allreduce_sum(arr):
# return arr
def distaxis(arr):
    """Return the axis along which *arr* is distributed.

    Presumably -1 means "not distributed" (cf. from_global_data) —
    this function merely reports the stored value.
    """
    axis = arr._distaxis
    return axis
def from_local_data(shape, arr, distaxis):
    """Build a data_object from an already-split local chunk.

    `shape` is the global shape, `arr` the data held by this task and
    `distaxis` the axis along which the global array is distributed.
    """
    return data_object(shape, arr, distaxis)
def from_global_data(arr, distaxis=0):
    """Build a data_object from the full global array `arr`.

    Each task keeps only its own share of `arr` along `distaxis`;
    with distaxis == -1 the array is stored undistributed.
    """
    if distaxis == -1:
        return data_object(arr.shape, arr, distaxis)
    lo, hi = shareRange(arr.shape[distaxis], ntask, rank)
    sl = [slice(None)]*len(arr.shape)
    # fix: the computed range was never applied, so every task wrapped
    # the full array instead of its local share
    sl[distaxis] = slice(lo, hi)
    return data_object(arr.shape, arr[sl], distaxis)
# Change the distribution axis of `arr`. Exactly one of `dist` (the new
# distribution axis) or `nodist` (a set of axes that must NOT end up
# distributed) is expected.
#
# NOTE(review): this block was recovered from a rendered diff and is
# visibly corrupted by the extraction — several lines are truncated
# (unbalanced parentheses, missing sub-expressions), Python-2 `print`
# statements remain, and names like sz/disp/ssz/rsz/sdisp/rdisp are
# referenced without their (lost) definitions. Reconcile against the
# repository before relying on anything below.
def redistribute (arr, dist=None, nodist=None):
if dist is not None:
if nodist is not None:
raise ValueError
# already distributed along the requested axis: nothing to do
if dist==arr._distaxis:
return arr
if nodist is None:
raise ValueError
# current distribution axis is acceptable: nothing to do
if arr._distaxis not in nodist:
return arr
# pick some axis not in `nodist` as the new distribution target
for i in range(len(arr.shape)):
if i not in nodist:
if arr._distaxis==-1: # just pick the proper subset
return from_global_data(arr._data, dist)
if dist==-1: # gather data
# NOTE(review): line below is truncated by extraction
tmp = np.moveaxis(arr._data, arr._distaxis, 0)[1:])*tmp.itemsize
for i in range(ntask):
out = np.empty(arr.size,dtype=arr.dtype)
# NOTE(review): Python-2 print; sz/disp undefined in this view
print tmp.shape, out.shape, sz, disp
out = out.reshape(arr._shape)
out = np.moveaxis(out, 0, arr._distaxis)
return from_global_data (out, distaxis=-1)
# real redistribution via Alltoallv
tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1))
# NOTE(review): line below is truncated by extraction
tshape = tmp.shape[2:])*tmp.itemsize
for i in range(ntask):
# NOTE(review): Python-2 print; ssz/rsz undefined in this view
print ssz, rsz
# NOTE(review): line below is truncated by extraction
out = np.empty(,dist)),dtype=arr.dtype)
s_msg = [tmp, (ssz, sdisp), MPI.BYTE]
r_msg = [out, (rsz, rdisp), MPI.BYTE]
comm.Alltoallv(s_msg, r_msg)
new_shape = [shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:])
out = np.moveaxis(out, (0, 1), (dist, arr._distaxis))
return from_local_data (arr.shape, out, dist)
def default_distaxis():
    """Axis along which newly created data_objects are distributed."""
    return 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment