Commit 1909f6ca authored by Martin Reinecke

tweaks

parent f872892a
@@ -2,15 +2,17 @@ import numpy as np
 from .random import Random
 from mpi4py import MPI
-comm = MPI.COMM_WORLD
-ntask = comm.Get_size()
-rank = comm.Get_rank()
+_comm = MPI.COMM_WORLD
+ntask = _comm.Get_size()
+rank = _comm.Get_rank()
+master = rank==0
 
-def shareSize(nwork, nshares, myshare):
+def _shareSize(nwork, nshares, myshare):
     nbase = nwork//nshares
     return nbase if myshare>=nwork%nshares else nbase+1
 
-def shareRange(nwork, nshares, myshare):
+def _shareRange(nwork, nshares, myshare):
     nbase = nwork//nshares;
     additional = nwork%nshares;
     lo = myshare*nbase + min(myshare, additional)
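Note: the renamed helpers split nwork items as evenly as possible over nshares shares. A minimal standalone sketch of their behavior (no MPI needed; the tail of _shareRange is cut off in this hunk, so the hi computation below is a reconstruction chosen only to agree with _shareSize):

    def _shareSize(nwork, nshares, myshare):
        nbase = nwork//nshares
        return nbase if myshare >= nwork % nshares else nbase+1

    def _shareRange(nwork, nshares, myshare):
        nbase = nwork//nshares
        additional = nwork % nshares
        lo = myshare*nbase + min(myshare, additional)
        hi = lo + nbase + (1 if myshare < additional else 0)  # reconstructed
        return lo, hi

    # 10 items over 4 shares: the first nwork%nshares shares get one extra
    # item, and the ranges tile [0, 10) without gaps or overlap.
    assert [_shareSize(10, 4, i) for i in range(4)] == [3, 3, 2, 2]
    assert [_shareRange(10, 4, i) for i in range(4)] == \
        [(0, 3), (3, 6), (6, 8), (8, 10)]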
@@ -23,7 +25,7 @@ def local_shape(shape, distaxis):
     if distaxis==-1:
         return shape
     shape2=list(shape)
-    shape2[distaxis]=shareSize(shape[distaxis],ntask,rank)
+    shape2[distaxis]=_shareSize(shape[distaxis],ntask,rank)
     return tuple(shape2)
 
 class data_object(object):
@@ -36,19 +38,19 @@ class data_object(object):
         lshape = local_shape(self._shape, self._distaxis)
         self._data = data
 
-    def sanity_checks(self):
+    def _sanity_checks(self):
         # check whether the distaxis is consistent
         if self._distaxis<-1 or self._distaxis>=len(self._shape):
             raise ValueError
         itmp=np.array(self._distaxis)
         otmp=np.empty(ntask,dtype=np.int)
-        comm.Allgather(itmp,otmp)
+        _comm.Allgather(itmp,otmp)
         if np.any(otmp!=self._distaxis):
             raise ValueError
         # check whether the global shape is consistent
         itmp=np.array(self._shape)
         otmp=np.empty((ntask,len(self._shape)),dtype=np.int)
-        comm.Allgather(itmp,otmp)
+        _comm.Allgather(itmp,otmp)
         for i in range(ntask):
             if np.any(otmp[i,:]!=self._shape):
                 raise ValueError
...@@ -58,7 +60,7 @@ class data_object(object): ...@@ -58,7 +60,7 @@ class data_object(object):
raise ValueError raise ValueError
else: else:
itmp=np.array(self._shape) itmp=np.array(self._shape)
itmp[self._distaxis] = shareSize(self._shape[self._distaxis],ntask,rank) itmp[self._distaxis] = _shareSize(self._shape[self._distaxis],ntask,rank)
if np.any(self._data.shape!=itmp): if np.any(self._data.shape!=itmp):
raise ValueError raise ValueError
@@ -91,13 +93,13 @@ class data_object(object):
             if (self._distaxis==-1):
                 return res[0]
             res2 = np.empty(1,dtype=res.dtype)
-            comm.Allreduce(res,res2,mpiop)
+            _comm.Allreduce(res,res2,mpiop)
             return res2[0]
 
         if self._distaxis in axis:
             res = getattr(self._data, op)(axis=axis)
             res2 = np.empty_like(res)
-            comm.Allreduce(res,res2,mpiop)
+            _comm.Allreduce(res,res2,mpiop)
             return from_global_data(res2, distaxis=0)
         else:
             # perform the contraction on the local data
@@ -209,12 +211,6 @@ class data_object(object):
     def __abs__(self):
         return data_object(self._shape,np.abs(self._data),self._distaxis)
 
-    #def ravel(self):
-    #    return data_object(self._data.ravel())
-
-    #def reshape(self, shape):
-    #    return data_object(self._data.reshape(shape))
-
     def all(self):
         return self._data.all()
@@ -243,9 +239,9 @@ def empty_like(a, dtype=None):
 
 def vdot(a, b):
-    tmp = np.vdot(a._data.ravel(), b._data.ravel())
+    tmp = np.vdot(a._data, b._data)
     res = np.empty(1,dtype=type(tmp))
-    comm.Allreduce(tmp,res,MPI.SUM)
+    _comm.Allreduce(tmp,res,MPI.SUM)
     return res[0]
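Note: the explicit ravel() calls can be dropped here because np.vdot flattens multidimensional inputs itself; each task reduces its own slab, then the partial results are summed across tasks. A sketch of the same pattern in bare mpi4py (hypothetical local slabs, not part of the commit):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    a_loc = np.ones((2, 3))                  # this task's slab of a
    b_loc = np.full((2, 3), 2.0)             # this task's slab of b
    tmp = np.array(np.vdot(a_loc, b_loc))    # local partial result: 12.0
    res = np.empty(1, dtype=tmp.dtype)
    comm.Allreduce(tmp, res, MPI.SUM)        # sum partials over all tasks
    # res[0] == 12.0 * comm.Get_size()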
@@ -296,13 +292,13 @@ def local_data(arr):
 def ibegin(arr):
     res = [0] * arr._data.ndim
-    res[arr._distaxis] = shareRange(arr._shape[arr._distaxis],ntask,rank)[0]
+    res[arr._distaxis] = _shareRange(arr._shape[arr._distaxis],ntask,rank)[0]
     return tuple(res)
 
 def np_allreduce_sum(arr):
     res = np.empty_like(arr)
-    comm.Allreduce(arr,res,MPI.SUM)
+    _comm.Allreduce(arr,res,MPI.SUM)
     return res
@@ -317,7 +313,7 @@ def from_local_data (shape, arr, distaxis):
 def from_global_data (arr, distaxis=0):
     if distaxis==-1:
         return data_object(arr.shape, arr, distaxis)
-    lo, hi = shareRange(arr.shape[distaxis],ntask,rank)
+    lo, hi = _shareRange(arr.shape[distaxis],ntask,rank)
     sl = [slice(None)]*len(arr.shape)
     sl[distaxis]=slice(lo,hi)
     return data_object(arr.shape, arr[sl], distaxis)
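Note: from_global_data keeps only the calling task's slab of the global array along distaxis. The slicing logic, shown without MPI for fixed, hypothetical values of ntask and rank:

    import numpy as np

    arr = np.arange(24).reshape(6, 4)
    distaxis = 0
    lo, hi = 2, 4                        # what _shareRange(6, 3, 1) yields
    sl = [slice(None)]*len(arr.shape)
    sl[distaxis] = slice(lo, hi)
    local = arr[tuple(sl)]               # rows 2..3 stay on this task
    assert local.shape == (2, 4)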
@@ -354,13 +350,13 @@ def redistribute (arr, dist=None, nodist=None):
         slabsize=np.prod(tmp.shape[1:])*tmp.itemsize
         sz=np.empty(ntask,dtype=np.int)
         for i in range(ntask):
-            sz[i]=slabsize*shareSize(arr.shape[arr._distaxis],ntask,i)
+            sz[i]=slabsize*_shareSize(arr.shape[arr._distaxis],ntask,i)
         disp=np.empty(ntask,dtype=np.int)
         disp[0]=0
         disp[1:]=np.cumsum(sz[:-1])
         tmp=tmp.flatten()
         out = np.empty(arr.size,dtype=arr.dtype)
-        comm.Allgatherv(tmp,[out,sz,disp,MPI.BYTE])
+        _comm.Allgatherv(tmp,[out,sz,disp,MPI.BYTE])
         shp = np.array(arr._shape)
         shp[1:arr._distaxis+1] = shp[0:arr._distaxis]
         shp[0] = arr.shape[arr._distaxis]
...@@ -377,8 +373,8 @@ def redistribute (arr, dist=None, nodist=None): ...@@ -377,8 +373,8 @@ def redistribute (arr, dist=None, nodist=None):
ssz=np.empty(ntask,dtype=np.int) ssz=np.empty(ntask,dtype=np.int)
rsz=np.empty(ntask,dtype=np.int) rsz=np.empty(ntask,dtype=np.int)
for i in range(ntask): for i in range(ntask):
ssz[i]=shareSize(arr.shape[dist],ntask,i)*tmp.shape[1]*slabsize ssz[i]=_shareSize(arr.shape[dist],ntask,i)*tmp.shape[1]*slabsize
rsz[i]=shareSize(arr.shape[dist],ntask,rank)*shareSize(arr.shape[arr._distaxis],ntask,i)*slabsize rsz[i]=_shareSize(arr.shape[dist],ntask,rank)*_shareSize(arr.shape[arr._distaxis],ntask,i)*slabsize
sdisp=np.empty(ntask,dtype=np.int) sdisp=np.empty(ntask,dtype=np.int)
rdisp=np.empty(ntask,dtype=np.int) rdisp=np.empty(ntask,dtype=np.int)
sdisp[0]=0 sdisp[0]=0
@@ -389,15 +385,15 @@ def redistribute (arr, dist=None, nodist=None):
     out = np.empty(np.prod(local_shape(arr.shape,dist)),dtype=arr.dtype)
     s_msg = [tmp, (ssz, sdisp), MPI.BYTE]
     r_msg = [out, (rsz, rdisp), MPI.BYTE]
-    comm.Alltoallv(s_msg, r_msg)
-    out2 = np.empty([shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:]), dtype=arr.dtype)
+    _comm.Alltoallv(s_msg, r_msg)
+    out2 = np.empty([_shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:]), dtype=arr.dtype)
     ofs=0
     for i in range(ntask):
         lsize = rsz[i]//tmp.itemsize
-        lo,hi = shareRange(arr.shape[arr._distaxis],ntask,i)
-        out2[slice(None),slice(lo,hi)] = out[ofs:ofs+lsize].reshape([shareSize(arr.shape[dist],ntask,rank),shareSize(arr.shape[arr._distaxis],ntask,i)]+list(tshape[2:]))
+        lo,hi = _shareRange(arr.shape[arr._distaxis],ntask,i)
+        out2[slice(None),slice(lo,hi)] = out[ofs:ofs+lsize].reshape([_shareSize(arr.shape[dist],ntask,rank),_shareSize(arr.shape[arr._distaxis],ntask,i)]+list(tshape[2:]))
         ofs += lsize
-    new_shape = [shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:])
+    new_shape = [_shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:])
     out2=out2.reshape(new_shape)
     out2 = np.moveaxis(out2, (0, 1), (dist, arr._distaxis))
     return from_local_data (arr.shape, out2, dist)
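Note: redistribute keeps the global shape but moves the distributed axis, exchanging slabs via Alltoallv. A hedged usage sketch, run under MPI (e.g. mpirun -np 4) with the names defined in this file; not part of the commit:

    glob = np.arange(6*8).reshape(6, 8)
    a = from_global_data(glob, distaxis=0)   # rows split across tasks
    b = redistribute(a, dist=1)              # columns split instead
    assert b.shape == a.shape                # global shape is unchanged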
...
@@ -104,6 +104,9 @@ class data_object(object):
     def __rdiv__(self, other):
         return self._binary_helper(other, op='__rdiv__')
 
+    def __idiv__(self, other):
+        return self._binary_helper(other, op='__idiv__')
+
     def __truediv__(self, other):
         return self._binary_helper(other, op='__truediv__')
@@ -131,12 +134,6 @@ class data_object(object):
     def __abs__(self):
         return data_object(np.abs(self._data))
 
-    def ravel(self):
-        return data_object(self._data.ravel())
-
-    def reshape(self, shape):
-        return data_object(self._data.reshape(shape))
-
     def all(self):
         return self._data.all()
...
# Estimate the per-component mean and variance of `operation` applied to
# nprobes random fields. (Field is assumed to be provided by the package's
# field module, e.g. "from .field import Field".)
def probe_operation(operation, domain, nprobes, random_type, dtype):
    for i in range(nprobes):
        f = Field.from_random(random_type=random_type, domain=domain,
                              dtype=dtype)
        tmp = operation(f)
        if i == 0:
            mean = [0]*len(tmp)
            var = [0]*len(tmp)
        for j in range(len(tmp)):   # `j` must not shadow the probe index
            mean[j] += tmp[j]
            var[j] += tmp[j]**2
    for j in range(len(tmp)):
        mean[j] *= 1./nprobes
        var[j] *= 1./nprobes
        var[j] -= mean[j]**2        # var = E[x^2] - E[x]^2
    return mean, var
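Note: probe_operation accumulates first and second moments and returns the population variance via var = E[x^2] - E[x]^2. A self-contained check of that accumulation, with scalar probes standing in for Fields:

    import numpy as np

    rng = np.random.default_rng(42)
    probes = rng.normal(size=10000)
    mean = probes.sum()/probes.size
    var = (probes**2).sum()/probes.size - mean**2
    # standard-normal probes: mean ~ 0, variance ~ 1
    assert abs(mean) < 0.05 and abs(var - 1.0) < 0.05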