Commit 2ce453c2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix FFT, cleanups

parent 4bed5bd5
Pipeline #21444 passed with stage
in 5 minutes and 6 seconds
...@@ -12,63 +12,66 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -12,63 +12,66 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
rank = _comm.Get_rank() rank = _comm.Get_rank()
master = rank==0 master = rank == 0
def _shareSize(nwork, nshares, myshare): def _shareSize(nwork, nshares, myshare):
nbase = nwork//nshares nbase = nwork//nshares
return nbase if myshare>=nwork%nshares else nbase+1 return nbase if myshare >= nwork % nshares else nbase+1
def _shareRange(nwork, nshares, myshare): def _shareRange(nwork, nshares, myshare):
nbase = nwork//nshares; nbase = nwork//nshares
additional = nwork%nshares; additional = nwork % nshares
lo = myshare*nbase + min(myshare, additional) lo = myshare*nbase + min(myshare, additional)
hi = lo+nbase+ (1 if myshare<additional else 0) hi = lo + nbase + (1 if myshare < additional else 0)
return lo,hi return lo, hi
def local_shape(shape, distaxis): def local_shape(shape, distaxis):
if len(shape)==0: if len(shape) == 0:
distaxis = -1 distaxis = -1
if distaxis==-1: if distaxis == -1:
return shape return shape
shape2=list(shape) shape2 = list(shape)
shape2[distaxis]=_shareSize(shape[distaxis],ntask,rank) shape2[distaxis] = _shareSize(shape[distaxis], ntask, rank)
return tuple(shape2) return tuple(shape2)
class data_object(object): class data_object(object):
def __init__(self, shape, data, distaxis): def __init__(self, shape, data, distaxis):
"""Must not be called directly by users""" """Must not be called directly by users"""
self._shape = tuple(shape) self._shape = tuple(shape)
if len(self._shape)==0: if len(self._shape) == 0:
distaxis = -1 distaxis = -1
self._distaxis = distaxis self._distaxis = distaxis
lshape = local_shape(self._shape, self._distaxis)
self._data = data self._data = data
def _sanity_checks(self): def _sanity_checks(self):
# check whether the distaxis is consistent # check whether the distaxis is consistent
if self._distaxis<-1 or self._distaxis>=len(self._shape): if self._distaxis < -1 or self._distaxis >= len(self._shape):
raise ValueError raise ValueError
itmp=np.array(self._distaxis) itmp = np.array(self._distaxis)
otmp=np.empty(ntask,dtype=np.int) otmp = np.empty(ntask, dtype=np.int)
_comm.Allgather(itmp,otmp) _comm.Allgather(itmp, otmp)
if np.any(otmp!=self._distaxis): if np.any(otmp != self._distaxis):
raise ValueError raise ValueError
# check whether the global shape is consistent # check whether the global shape is consistent
itmp=np.array(self._shape) itmp = np.array(self._shape)
otmp=np.empty((ntask,len(self._shape)),dtype=np.int) otmp = np.empty((ntask, len(self._shape)), dtype=np.int)
_comm.Allgather(itmp,otmp) _comm.Allgather(itmp, otmp)
for i in range(ntask): for i in range(ntask):
if np.any(otmp[i,:]!=self._shape): if np.any(otmp[i, :] != self._shape):
raise ValueError raise ValueError
# check shape of local data # check shape of local data
if self._distaxis<0: if self._distaxis < 0:
if self._data.shape!=self._shape: if self._data.shape != self._shape:
raise ValueError raise ValueError
else: else:
itmp=np.array(self._shape) itmp = np.array(self._shape)
itmp[self._distaxis] = _shareSize(self._shape[self._distaxis],ntask,rank) itmp[self._distaxis] = _shareSize(self._shape[self._distaxis],
if np.any(self._data.shape!=itmp): ntask, rank)
if np.any(self._data.shape != itmp):
raise ValueError raise ValueError
@property @property
...@@ -93,52 +96,50 @@ class data_object(object): ...@@ -93,52 +96,50 @@ class data_object(object):
def _contraction_helper(self, op, mpiop, axis): def _contraction_helper(self, op, mpiop, axis):
if axis is not None: if axis is not None:
if len(axis)==len(self._data.shape): if len(axis) == len(self._data.shape):
axis = None axis = None
if axis is None: if axis is None:
res = np.array(getattr(self._data, op)()) res = np.array(getattr(self._data, op)())
if (self._distaxis==-1): if (self._distaxis == -1):
return res[()] return res[()]
res2 = np.empty((),dtype=res.dtype) res2 = np.empty((), dtype=res.dtype)
_comm.Allreduce(res,res2,mpiop) _comm.Allreduce(res, res2, mpiop)
return res2[()] return res2[()]
if self._distaxis in axis: if self._distaxis in axis:
res = getattr(self._data, op)(axis=axis) res = getattr(self._data, op)(axis=axis)
res2 = np.empty_like(res) res2 = np.empty_like(res)
_comm.Allreduce(res,res2,mpiop) _comm.Allreduce(res, res2, mpiop)
return from_global_data(res2, distaxis=0) return from_global_data(res2, distaxis=0)
else: else:
# perform the contraction on the local data # perform the contraction on the local data
res = getattr(self._data, op)(axis=axis) res = getattr(self._data, op)(axis=axis)
if self._distaxis == -1: if self._distaxis == -1:
return from_global_data(res,distaxis=0) return from_global_data(res, distaxis=0)
shp = list(res.shape) shp = list(res.shape)
shift=0 shift = 0
for ax in axis: for ax in axis:
if ax<self._distaxis: if ax < self._distaxis:
shift+=1 shift += 1
shp[self._distaxis-shift] = self.shape[self._distaxis] shp[self._distaxis-shift] = self.shape[self._distaxis]
return from_local_data(shp, res, self._distaxis-shift) return from_local_data(shp, res, self._distaxis-shift)
# check if the result is scalar or if a result_field must be constr.
if np.isscalar(data):
return data
else:
return data_object(data)
def sum(self, axis=None): def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis) return self._contraction_helper("sum", MPI.SUM, axis)
def min(self, axis=None): def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis) return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None): def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis) return self._contraction_helper("max", MPI.MAX, axis)
# FIXME: to be improved! # FIXME: to be improved!
def mean(self): def mean(self):
return self.sum()/self.size return self.sum()/self.size
def std(self): def std(self):
return np.sqrt(self.var()) return np.sqrt(self.var())
def var(self): def var(self):
return (abs(self-self.mean())**2).mean() return (abs(self-self.mean())**2).mean()
...@@ -157,7 +158,10 @@ class data_object(object): ...@@ -157,7 +158,10 @@ class data_object(object):
b = other b = other
tval = getattr(a, op)(b) tval = getattr(a, op)(b)
return self if tval is a else data_object(self._shape, tval, self._distaxis) if tval is a:
return self
else:
return data_object(self._shape, tval, self._distaxis)
def __add__(self, other): def __add__(self, other):
return self._binary_helper(other, op='__add__') return self._binary_helper(other, op='__add__')
...@@ -217,10 +221,10 @@ class data_object(object): ...@@ -217,10 +221,10 @@ class data_object(object):
return self._binary_helper(other, op='__ne__') return self._binary_helper(other, op='__ne__')
def __neg__(self): def __neg__(self):
return data_object(self._shape,-self._data,self._distaxis) return data_object(self._shape, -self._data, self._distaxis)
def __abs__(self): def __abs__(self):
return data_object(self._shape,np.abs(self._data),self._distaxis) return data_object(self._shape, np.abs(self._data), self._distaxis)
def all(self): def all(self):
return self._data.all() return self._data.all()
...@@ -230,19 +234,23 @@ class data_object(object): ...@@ -230,19 +234,23 @@ class data_object(object):
def full(shape, fill_value, dtype=None, distaxis=0): def full(shape, fill_value, dtype=None, distaxis=0):
return data_object(shape, np.full(local_shape(shape, distaxis), fill_value, dtype), distaxis) return data_object(shape, np.full(local_shape(shape, distaxis),
fill_value, dtype), distaxis)
def empty(shape, dtype=None, distaxis=0): def empty(shape, dtype=None, distaxis=0):
return data_object(shape, np.empty(local_shape(shape, distaxis), dtype), distaxis) return data_object(shape, np.empty(local_shape(shape, distaxis),
dtype), distaxis)
def zeros(shape, dtype=None, distaxis=0): def zeros(shape, dtype=None, distaxis=0):
return data_object(shape, np.zeros(local_shape(shape, distaxis), dtype), distaxis) return data_object(shape, np.zeros(local_shape(shape, distaxis), dtype),
distaxis)
def ones(shape, dtype=None, distaxis=0): def ones(shape, dtype=None, distaxis=0):
return data_object(shape, np.ones(local_shape(shape, distaxis), dtype), distaxis) return data_object(shape, np.ones(local_shape(shape, distaxis), dtype),
distaxis)
def empty_like(a, dtype=None): def empty_like(a, dtype=None):
...@@ -251,8 +259,8 @@ def empty_like(a, dtype=None): ...@@ -251,8 +259,8 @@ def empty_like(a, dtype=None):
def vdot(a, b): def vdot(a, b):
tmp = np.array(np.vdot(a._data, b._data)) tmp = np.array(np.vdot(a._data, b._data))
res = np.empty((),dtype=tmp.dtype) res = np.empty((), dtype=tmp.dtype)
_comm.Allreduce(tmp,res,MPI.SUM) _comm.Allreduce(tmp, res, MPI.SUM)
return res[()] return res[()]
...@@ -261,7 +269,7 @@ def _math_helper(x, function, out): ...@@ -261,7 +269,7 @@ def _math_helper(x, function, out):
function(x._data, out=out._data) function(x._data, out=out._data)
return out return out
else: else:
return data_object(x.shape,function(x._data),x._distaxis) return data_object(x.shape, function(x._data), x._distaxis)
def abs(a, out=None): def abs(a, out=None):
...@@ -288,28 +296,31 @@ def bincount(x, weights=None, minlength=None): ...@@ -288,28 +296,31 @@ def bincount(x, weights=None, minlength=None):
def from_object(object, dtype=None, copy=True): def from_object(object, dtype=None, copy=True):
return data_object(object._shape, np.array(object._data, dtype=dtype, copy=copy), distaxis=object._distaxis) return data_object(object._shape, np.array(object._data, dtype=dtype,
copy=copy),
distaxis=object._distaxis)
def from_random(random_type, shape, dtype=np.float64, distaxis=0, **kwargs): def from_random(random_type, shape, dtype=np.float64, distaxis=0, **kwargs):
generator_function = getattr(Random, random_type) generator_function = getattr(Random, random_type)
#lshape = local_shape(shape, distaxis) # lshape = local_shape(shape, distaxis)
#return data_object(shape, generator_function(dtype=dtype, shape=lshape, **kwargs), distaxis=distaxis) # return data_object(shape, generator_function(dtype=dtype, shape=lshape, **kwargs), distaxis=distaxis)
return from_global_data(generator_function(dtype=dtype, shape=shape, **kwargs), distaxis=distaxis) return from_global_data(generator_function(dtype=dtype, shape=shape, **kwargs), distaxis=distaxis)
def local_data(arr): def local_data(arr):
return arr._data return arr._data
def ibegin(arr): def ibegin(arr):
res = [0] * arr._data.ndim res = [0] * arr._data.ndim
res[arr._distaxis] = _shareRange(arr._shape[arr._distaxis],ntask,rank)[0] res[arr._distaxis] = _shareRange(arr._shape[arr._distaxis], ntask, rank)[0]
return tuple(res) return tuple(res)
def np_allreduce_sum(arr): def np_allreduce_sum(arr):
res = np.empty_like(arr) res = np.empty_like(arr)
_comm.Allreduce(arr,res,MPI.SUM) _comm.Allreduce(arr, res, MPI.SUM)
return res return res
...@@ -317,97 +328,102 @@ def distaxis(arr): ...@@ -317,97 +328,102 @@ def distaxis(arr):
return arr._distaxis return arr._distaxis
def from_local_data (shape, arr, distaxis): def from_local_data(shape, arr, distaxis):
return data_object(shape, arr, distaxis) return data_object(shape, arr, distaxis)
def from_global_data (arr, distaxis=0): def from_global_data(arr, distaxis=0):
if distaxis==-1: if distaxis == -1:
return data_object(arr.shape, arr, distaxis) return data_object(arr.shape, arr, distaxis)
lo, hi = _shareRange(arr.shape[distaxis],ntask,rank) lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
sl = [slice(None)]*len(arr.shape) sl = [slice(None)]*len(arr.shape)
sl[distaxis]=slice(lo,hi) sl[distaxis] = slice(lo, hi)
return data_object(arr.shape, arr[sl], distaxis) return data_object(arr.shape, arr[sl], distaxis)
def to_global_data (arr): def to_global_data(arr):
if arr._distaxis==-1: if arr._distaxis == -1:
return arr._data return arr._data
tmp = redistribute(arr, dist=-1) tmp = redistribute(arr, dist=-1)
return tmp._data return tmp._data
def redistribute (arr, dist=None, nodist=None): def redistribute(arr, dist=None, nodist=None):
if dist is not None: if dist is not None:
if nodist is not None: if nodist is not None:
raise ValueError raise ValueError
if dist==arr._distaxis: if dist == arr._distaxis:
return arr return arr
else: else:
if nodist is None: if nodist is None:
raise ValueError raise ValueError
if arr._distaxis not in nodist: if arr._distaxis not in nodist:
return arr return arr
dist=-1 dist = -1
for i in range(len(arr.shape)): for i in range(len(arr.shape)):
if i not in nodist: if i not in nodist:
dist=i dist = i
break break
if arr._distaxis==-1: # just pick the proper subset if arr._distaxis == -1: # just pick the proper subset
return from_global_data(arr._data, dist) return from_global_data(arr._data, dist)
if dist==-1: # gather data if dist == -1: # gather data
tmp = np.moveaxis(arr._data, arr._distaxis, 0) tmp = np.moveaxis(arr._data, arr._distaxis, 0)
slabsize=np.prod(tmp.shape[1:])*tmp.itemsize slabsize = np.prod(tmp.shape[1:])*tmp.itemsize
sz=np.empty(ntask,dtype=np.int) sz = np.empty(ntask, dtype=np.int)
for i in range(ntask): for i in range(ntask):
sz[i]=slabsize*_shareSize(arr.shape[arr._distaxis],ntask,i) sz[i] = slabsize*_shareSize(arr.shape[arr._distaxis], ntask, i)
disp=np.empty(ntask,dtype=np.int) disp = np.empty(ntask, dtype=np.int)
disp[0]=0 disp[0] = 0
disp[1:]=np.cumsum(sz[:-1]) disp[1:] = np.cumsum(sz[:-1])
tmp=tmp.flatten() tmp = tmp.flatten()
out = np.empty(arr.size,dtype=arr.dtype) out = np.empty(arr.size, dtype=arr.dtype)
_comm.Allgatherv(tmp,[out,sz,disp,MPI.BYTE]) _comm.Allgatherv(tmp, [out, sz, disp, MPI.BYTE])
shp = np.array(arr._shape) shp = np.array(arr._shape)
shp[1:arr._distaxis+1] = shp[0:arr._distaxis] shp[1:arr._distaxis+1] = shp[0:arr._distaxis]
shp[0] = arr.shape[arr._distaxis] shp[0] = arr.shape[arr._distaxis]
out = out.reshape(shp) out = out.reshape(shp)
out = np.moveaxis(out, 0, arr._distaxis) out = np.moveaxis(out, 0, arr._distaxis)
return from_global_data (out, distaxis=-1) return from_global_data(out, distaxis=-1)
# real redistribution via Alltoallv # real redistribution via Alltoallv
# temporary slow, but simple solution for comparison purposes: # temporary slow, but simple solution for comparison purposes:
#return redistribute(redistribute(arr,dist=-1),dist=dist) # return redistribute(redistribute(arr,dist=-1),dist=dist)
tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1)) tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1))
tshape = tmp.shape tshape = tmp.shape
slabsize=np.prod(tmp.shape[2:])*tmp.itemsize slabsize = np.prod(tmp.shape[2:])*tmp.itemsize
ssz=np.empty(ntask,dtype=np.int) ssz = np.empty(ntask, dtype=np.int)
rsz=np.empty(ntask,dtype=np.int) rsz = np.empty(ntask, dtype=np.int)
for i in range(ntask): for i in range(ntask):
ssz[i]=_shareSize(arr.shape[dist],ntask,i)*tmp.shape[1]*slabsize ssz[i] = _shareSize(arr.shape[dist], ntask, i)*tmp.shape[1]*slabsize
rsz[i]=_shareSize(arr.shape[dist],ntask,rank)*_shareSize(arr.shape[arr._distaxis],ntask,i)*slabsize rsz[i] = _shareSize(arr.shape[dist], ntask, rank) * \
sdisp=np.empty(ntask,dtype=np.int) _shareSize(arr.shape[arr._distaxis], ntask, i) * \
rdisp=np.empty(ntask,dtype=np.int) slabsize
sdisp[0]=0 sdisp = np.empty(ntask, dtype=np.int)
rdisp[0]=0 rdisp = np.empty(ntask, dtype=np.int)
sdisp[1:]=np.cumsum(ssz[:-1]) sdisp[0] = 0
rdisp[1:]=np.cumsum(rsz[:-1]) rdisp[0] = 0
tmp=tmp.flatten() sdisp[1:] = np.cumsum(ssz[:-1])
out = np.empty(np.prod(local_shape(arr.shape,dist)),dtype=arr.dtype) rdisp[1:] = np.cumsum(rsz[:-1])
tmp = tmp.flatten()
out = np.empty(np.prod(local_shape(arr.shape, dist)), dtype=arr.dtype)
s_msg = [tmp, (ssz, sdisp), MPI.BYTE] s_msg = [tmp, (ssz, sdisp), MPI.BYTE]
r_msg = [out, (rsz, rdisp), MPI.BYTE] r_msg = [out, (rsz, rdisp), MPI.BYTE]
_comm.Alltoallv(s_msg, r_msg) _comm.Alltoallv(s_msg, r_msg)
out2 = np.empty([_shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:]), dtype=arr.dtype) out2 = np.empty([_shareSize(arr.shape[dist], ntask, rank),
ofs=0 arr.shape[arr._distaxis]] + list(tshape[2:]),
dtype=arr.dtype)
ofs = 0
for i in range(ntask): for i in range(ntask):
lsize = rsz[i]//tmp.itemsize lsize = rsz[i]//tmp.itemsize
lo,hi = _shareRange(arr.shape[arr._distaxis],ntask,i) lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i)
out2[slice(None),slice(lo,hi)] = out[ofs:ofs+lsize].reshape([_shareSize(arr.shape[dist],ntask,rank),_shareSize(arr.shape[arr._distaxis],ntask,i)]+list(tshape[2:])) out2[slice(None), slice(lo, hi)] = \
out[ofs:ofs+lsize].reshape([_shareSize(arr.shape[dist], ntask, rank),_shareSize(arr.shape[arr._distaxis],ntask,i)]+list(tshape[2:]))
ofs += lsize ofs += lsize
new_shape = [_shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:]) new_shape = [_shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:])
out2=out2.reshape(new_shape) out2 = out2.reshape(new_shape)
out2 = np.moveaxis(out2, (0, 1), (dist, arr._distaxis)) out2 = np.moveaxis(out2, (0, 1), (dist, arr._distaxis))
return from_local_data (arr.shape, out2, dist) return from_local_data(arr.shape, out2, dist)
def default_distaxis(): def default_distaxis():
......
...@@ -6,17 +6,6 @@ class data_object(object): ...@@ -6,17 +6,6 @@ class data_object(object):
def __init__(self, npdata): def __init__(self, npdata):
self._data = np.asarray(npdata) self._data = np.asarray(npdata)
# FIXME: subscripting support will most likely go away
def __getitem__(self, key):
res = self._data[key]
return res if np.isscalar(res) else data_object(res)
def __setitem__(self, key, value):
if isinstance(value, data_object):
self._data[key] = value._data
else:
self._data[key] = value
@property @property
def dtype(self): def dtype(self):
return self._data.dtype return self._data.dtype
...@@ -39,7 +28,7 @@ class data_object(object): ...@@ -39,7 +28,7 @@ class data_object(object):
def _contraction_helper(self, op, axis): def _contraction_helper(self, op, axis):