Commit f53b3afa authored by Martin Reinecke

progress

parent 153d1ce1
Pipeline #21248 passed with stage in 4 minutes and 16 seconds
@@ -10,6 +10,12 @@ rank = comm.Get_rank()
def shareSize(nwork, nshares, myshare):
nbase = nwork//nshares
return nbase if myshare>=nwork%nshares else nbase+1
def shareRange(nwork, nshares, myshare):
nbase = nwork//nshares
additional = nwork%nshares
lo = myshare*nbase + min(myshare, additional)
hi = lo + nbase + (1 if myshare<additional else 0)
return lo,hi
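A quick standalone check of the two sharing helpers above (plain Python 3, no MPI needed): splitting nwork=10 items over nshares=3 tasks should give sizes 4, 3, 3 and the matching half-open ranges.
for myshare in range(3):
    print(shareSize(10, 3, myshare), shareRange(10, 3, myshare))
# expected output:
# 4 (0, 4)
# 3 (4, 7)
# 3 (7, 10)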
def get_locshape(shape, distaxis):
if distaxis==-1:
@@ -37,10 +43,10 @@ class data_object(object):
raise ValueError
# check whether the global shape is consistent
itmp=np.array(self._shape)
otmp=np.empty((len(self._shape),ntask),dtype=np.int)
otmp=np.empty((ntask,len(self._shape)),dtype=np.int)
comm.Allgather(itmp,otmp)
for i in range(ntask):
if (otmp[i,:]!=self._shape).any():
if np.any(otmp[i,:]!=self._shape):
raise ValueError
# check shape of local data
if self._distaxis<0:
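The receive-buffer fix above matters because Allgather concatenates one contribution per rank along the first axis, so rank i's shape vector must land in row otmp[i, :]. A minimal sketch of that consistency check, assuming mpi4py and a hypothetical global shape of (16, 8):
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
ntask = comm.Get_size()
shape = (16, 8)  # assumed global shape, identical on every task
itmp = np.array(shape)
otmp = np.empty((ntask, len(shape)), dtype=itmp.dtype)  # one row per task
comm.Allgather(itmp, otmp)
for i in range(ntask):
    if np.any(otmp[i, :] != shape):
        raise ValueError("inconsistent global shapes across tasks")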
@@ -48,8 +54,8 @@ class data_object(object):
raise ValueError
else:
itmp=np.array(self._shape)
itmp[self._distaxis] = get_local_length(self._shape[self._distaxis],ntask,rank)
if self._data.shape!=itmp:
itmp[self._distaxis] = shareSize(self._shape[self._distaxis],ntask,rank)
if np.any(self._data.shape!=itmp):
raise ValueError
@property
@@ -66,27 +72,30 @@ class data_object(object):
@property
def real(self):
return data_object(self._shape, self._data.real, self._dist_axis)
return data_object(self._shape, self._data.real, self._distaxis)
@property
def imag(self):
return data_object(self._shape, self._data.imag, self._dist_axis)
return data_object(self._shape, self._data.imag, self._distaxis)
def _contraction_helper(self, op, axis):
def _contraction_helper(self, op, mpiop, axis):
if axis is not None:
if len(axis)==len(self._data.shape):
axis = None
if axis is None:
res = getattr(self._data, op)()
res = np.array(getattr(self._data, op)())
if (self._distaxis==-1):
return res
res2 = np.empty(1,dtype=res.dtype)
MPI.COMM_WORLD.Allreduce(res,res2,mpiop)
return res2[0]
if self._distaxis in axis:
pass  # reduce globally, redistribute the result along axis 0(?)
else:
pass  # reduce locally
# perform the contraction on the data
data = getattr(self._data, op)(axis=axis)
# perform the contraction on the local data
data = getattr(self._data, op)(axis=axis)
#shp =
# check if the result is scalar or if a result_field must be constr.
if np.isscalar(data):
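To see how the new mpiop argument is used on the axis=None path above: every task first reduces its own chunk, then a single Allreduce with the matching MPI operation combines the per-task partial results. A minimal sketch with made-up data, assuming mpi4py:
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local = np.arange(4, dtype=np.float64) + comm.Get_rank()  # this task's local chunk
res = np.array(local.sum())            # 0-d array, mirroring the code above
res2 = np.empty(1, dtype=res.dtype)
comm.Allreduce(res, res2, MPI.SUM)     # MPI.SUM corresponds to op="sum"
print("global sum:", res2[0])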
@@ -98,14 +107,17 @@ class data_object(object):
return self._contraction_helper("sum", MPI.SUM, axis)
def _binary_helper(self, other, op):
a = self._data
a = self
if isinstance(other, data_object):
b = other._data
b = other
if a._shape != b._shape:
raise ValueError("shapes are incompatible.")
if a._distaxis != b._distaxis:
raise ValueError("distributions are incompatible.")
a = a._data
b = b._data
else:
a = a._data
b = other
tval = getattr(a, op)(b)
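The reworked helper above now validates two distributed operands before unwrapping them: elementwise operations require the same global shape and the same distribution axis, while a scalar or plain-array operand skips both checks. A hedged usage sketch (it assumes the data_object arithmetic operators route through _binary_helper, which the rest of this file suggests but this hunk does not show):
a = full((16, 8), 1., distaxis=0)
b = full((16, 8), 2., distaxis=0)
c = a + b        # fine: same shape, same distaxis
d = full((16, 8), 2., distaxis=1)
# a + d          # raises ValueError: distributions are incompatible.
# a + 3.         # fine: scalar operand, no compatibility checks needed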
@@ -184,8 +196,8 @@ class data_object(object):
return self._data.any()
def full(shape, fill_value, dtype=None, dist_axis=0):
return data_object(shape, np.full(shape, local_shape(shape, dist_axis), fill_value, dtype))
def full(shape, fill_value, dtype=None, distaxis=0):
return data_object(shape, np.full(get_locshape(shape, distaxis), fill_value, dtype), distaxis)
def empty(shape, dtype=np.float):
@@ -254,3 +266,97 @@ def to_ndarray(arr):
def from_ndarray(arr):
return data_object(arr.shape,arr,-1)
def local_data(arr):
return arr._data
#def ibegin(arr):
# return (0,)*arr._data.ndim
#def np_allreduce_sum(arr):
# return arr
def distaxis(arr):
return arr._distaxis
def from_local_data(shape, arr, distaxis):
return data_object(shape, arr, distaxis)
def from_global_data(arr, distaxis=0):
if distaxis==-1:
return data_object(arr.shape, arr, distaxis)
lo, hi = shareRange(arr.shape[distaxis],ntask,rank)
sl = [slice(None)]*len(arr.shape)
sl[distaxis] = slice(lo,hi)
return data_object(arr.shape, arr[tuple(sl)], distaxis)
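from_global_data therefore never communicates: every task already holds the full array and simply keeps its shareRange slice along distaxis. For instance, a (10, 4) array distributed over 3 tasks along axis 0 leaves rows 0:4 on rank 0, 4:7 on rank 1 and 7:10 on rank 2. A small single-process sketch of just the slicing (local_slice is a hypothetical helper for illustration, not part of this file; it reuses shareRange from above):
import numpy as np

def local_slice(glob, distaxis, ntask, rank):  # hypothetical helper
    lo, hi = shareRange(glob.shape[distaxis], ntask, rank)
    sl = [slice(None)] * glob.ndim
    sl[distaxis] = slice(lo, hi)
    return glob[tuple(sl)]

glob = np.arange(40).reshape(10, 4)
for r in range(3):
    print(r, local_slice(glob, 0, 3, r).shape)  # (4, 4), (3, 4), (3, 4)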
def redistribute(arr, dist=None, nodist=None):
if dist is not None:
if nodist is not None:
raise ValueError
if dist==arr._distaxis:
return arr
else:
if nodist is None:
raise ValueError
if arr._distaxis not in nodist:
return arr
dist=-1
for i in range(len(arr.shape)):
if i not in nodist:
dist=i
break
if arr._distaxis==-1: # just pick the proper subset
return from_global_data(arr._data, dist)
if dist==-1: # gather data
tmp = np.moveaxis(arr._data, arr._distaxis, 0)
slabsize=np.prod(tmp.shape[1:])*tmp.itemsize
sz=np.empty(ntask,dtype=np.int)
for i in range(ntask):
sz[i]=slabsize*shareSize(arr.shape[arr._distaxis],ntask,i)
disp=np.empty(ntask,dtype=np.int)
disp[0]=0
disp[1:]=np.cumsum(sz[:-1])
tmp=tmp.flatten()
out = np.empty(arr.size,dtype=arr.dtype)
#print tmp.shape, out.shape, sz, disp  # leftover debug output
comm.Allgatherv([tmp, MPI.BYTE], [out, sz, disp, MPI.BYTE])
# reshape to the axis-moved global shape before moving the axis back
gshape = [arr._shape[arr._distaxis]] + [arr._shape[i] for i in range(len(arr._shape)) if i != arr._distaxis]
out = out.reshape(gshape)
out = np.moveaxis(out, 0, arr._distaxis)
return from_global_data(out, distaxis=-1)
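Note that sz and disp in the gather branch above are measured in bytes (slabsize already includes itemsize), which is why both sides of the Allgatherv are typed MPI.BYTE. Checking the arithmetic for a hypothetical (10, 4) float64 array distributed over 3 tasks along axis 0, reusing shareSize from above:
import numpy as np

itemsize = 8                  # float64
slabsize = 4 * itemsize       # np.prod(shape[1:]) * itemsize = 32 bytes per row
sz = np.array([slabsize * shareSize(10, 3, i) for i in range(3)])
disp = np.concatenate(([0], np.cumsum(sz[:-1])))
print(sz, disp)               # [128  96  96] [  0 128 224]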
# real redistribution via Alltoallv
tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1))
tshape = tmp.shape
slabsize=np.prod(tmp.shape[2:])*tmp.itemsize
ssz=np.empty(ntask,dtype=np.int)
rsz=np.empty(ntask,dtype=np.int)
for i in range(ntask):
ssz[i]=slabsize*tmp.shape[1]*shareSize(arr.shape[dist],ntask,i)
rsz[i]=slabsize*shareSize(arr.shape[dist],ntask,rank)*shareSize(arr.shape[arr._distaxis],ntask,i)
sdisp=np.empty(ntask,dtype=np.int)
rdisp=np.empty(ntask,dtype=np.int)
sdisp[0]=0
rdisp[0]=0
sdisp[1:]=np.cumsum(ssz[:-1])
rdisp[1:]=np.cumsum(rsz[:-1])
#print ssz, rsz  # leftover debug output
tmp=tmp.flatten()
out = np.empty(np.prod(get_locshape(arr.shape,dist)),dtype=arr.dtype)
s_msg = [tmp, (ssz, sdisp), MPI.BYTE]
r_msg = [out, (rsz, rdisp), MPI.BYTE]
comm.Alltoallv(s_msg, r_msg)
# each received block carries only sender i's slab along the old distribution
# axis, so the blocks must be re-assembled along that axis; a single flat
# reshape would interleave them incorrectly
cuts = np.cumsum(rsz[:-1]//arr.dtype.itemsize)
parts = [p.reshape([shareSize(arr.shape[dist],ntask,rank), shareSize(arr.shape[arr._distaxis],ntask,i)] + list(tshape[2:])) for i, p in enumerate(np.split(out, cuts))]
out = np.concatenate(parts, axis=1)
out = np.moveaxis(out, (0, 1), (dist, arr._distaxis))
return from_local_data(arr.shape, out, dist)
def default_distaxis():
return 0
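Finally, a hedged end-to-end usage sketch of the new redistribute, run under something like "mpirun -np 3 python test.py"; it relies only on functions defined in this file:
import numpy as np

glob = np.arange(40.).reshape(10, 4)
a = from_global_data(glob, distaxis=0)   # each task keeps its block of rows
b = redistribute(a, dist=1)              # Alltoallv path: now split along axis 1
c = redistribute(b, nodist=(0, 1))       # gather path: distaxis becomes -1
assert np.all(local_data(c) == glob)     # every task again sees the full array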