import numpy as np from .random import Random from mpi4py import MPI comm = MPI.COMM_WORLD ntask = comm.Get_size() rank = comm.Get_rank() def shareSize(nwork, nshares, myshare): nbase = nwork//nshares return nbase if myshare>=nwork%nshares else nbase+1 def shareRange(nwork, nshares, myshare): nbase = nwork//nshares; additional = nwork%nshares; lo = myshare*nbase + min(myshare, additional) hi = lo+nbase+ (1 if myshare=len(self._shape): raise ValueError itmp=np.array(self._distaxis) otmp=np.empty(ntask,dtype=np.int) comm.Allgather(itmp,otmp) if np.any(otmp!=self._distaxis): raise ValueError # check whether the global shape is consistent itmp=np.array(self._shape) otmp=np.empty((ntask,len(self._shape)),dtype=np.int) comm.Allgather(itmp,otmp) for i in range(ntask): if np.any(otmp[i,:]!=self._shape): raise ValueError # check shape of local data if self._distaxis<0: if self._data.shape!=self._shape: raise ValueError else: itmp=np.array(self._shape) itmp[self._distaxis] = shareSize(self._shape[self._distaxis],ntask,rank) if np.any(self._data.shape!=itmp): raise ValueError @property def dtype(self): return self._data.dtype @property def shape(self): return self._shape @property def size(self): return np.prod(self._shape) @property def real(self): return data_object(self._shape, self._data.real, self._distaxis) @property def imag(self): return data_object(self._shape, self._data.imag, self._distaxis) def _contraction_helper(self, op, mpiop, axis): if axis is not None: if len(axis)==len(self._data.shape): axis = None if axis is None: res = np.array(getattr(self._data, op)()) if (self._distaxis==-1): return res[0] res2 = np.empty(1,dtype=res.dtype) comm.Allreduce(res,res2,mpiop) return res2[0] if self._distaxis in axis: res = getattr(self._data, op)(axis=axis) res2 = np.empty_like(res) comm.Allreduce(res,res2,mpiop) return from_global_data(res2, distaxis=0) else: # perform the contraction on the local data res = getattr(self._data, op)(axis=axis) if self._distaxis == -1: return from_global_data(res,distaxis=0) shp = list(res.shape) shift=0 for ax in axis: if ax