Commit 4f5692dd authored by Martin Reinecke's avatar Martin Reinecke
Browse files

towards MPI

parent 8dbc55d6
Pipeline #21180 passed with stage
in 4 minutes and 11 seconds
......@@ -236,3 +236,35 @@ def np_allreduce_sum(arr):
def dist_axis(arr):
return -1
def from_local_data (shape, arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return data_object(arr)
def from_global_data (arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return data_object(arr)
def redistribute (arr, dist=None, nodist=None):
if dist is not None and dist!=-1:
raise NotImplementedError
return arr
def default_dist_axis():
return -1
def local_shape(glob_shape, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
return glob_shape
......@@ -43,3 +43,35 @@ def np_allreduce_sum(arr):
def dist_axis(arr):
return -1
def from_local_data (shape, arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return arr
def from_global_data (arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return arr
def redistribute (arr, dist=None, nodist=None):
if dist is not None and dist!=-1:
raise NotImplementedError
return arr
def default_dist_axis():
return -1
def local_shape(glob_shape, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
return glob_shape
#from .data_objects.my_own_do import *
from .data_objects.numpy_do import *
from .data_objects.my_own_do import *
#from .data_objects.numpy_do import *
......@@ -114,16 +114,23 @@ class SphericalTransformation(Transformation):
def transform(self, x):
axes = x.domain.axes[self.space]
if dobj.dist_axis(x.val) in axes:
raise NotImplementedError
axis = axes[0]
tval = x.val
if dobj.dist_axis(tval) == axis:
tval = dobj.redistribute(tval, nodist=(axis,))
distaxis = dobj.dist_axis(tval)
p2h = x.domain == self.pdom
idat = dobj.local_data(x.val)
idat = dobj.local_data(tval)
if p2h:
res = Field(self.hdom, dtype=x.dtype)
odat = dobj.local_data(res.val)
odat = np.empty(dobj.local_shape(self.hdom.shape, dist_axis=distaxis), dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = self._slice_p2h(idat[slice])
odat = dobj.from_local_data(self.hdom.shape, odat, distaxis)
if distaxis!= dobj.dist_axis(x):
odat = dobj.redistribute(odat, dist=distaxis)
return Field(self.hdom, odat)
else:
res = Field(self.pdom, dtype=x.dtype)
odat = dobj.local_data(res.val)
......
......@@ -52,16 +52,30 @@ class PowerProjectionOperator(LinearOperator):
def _times(self, x):
pindex = self._target[self._space].pindex
pindex = pindex.reshape((1, pindex.size, 1))
arr = x.weight(1).val.reshape(
x.domain.collapsed_shape_for_domain(self._space))
out = dobj.zeros(self._target.collapsed_shape_for_domain(self._space),
dtype=x.dtype)
out = dobj.to_ndarray(out)
np.add.at(out, (slice(None), dobj.to_ndarray(pindex.ravel()), slice(None)), dobj.to_ndarray(arr))
out = dobj.from_ndarray(out)
return Field(self._target, out.reshape(self._target.shape))\
.weight(-1, spaces=self._space)
res = Field.zeros(self._target, dtype=x.dtype)
if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
else:
pindex = dobj.to_ndarray(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
return res.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
pindex = self._target[self._space].pindex
......
......@@ -143,14 +143,16 @@ class PowerSpace(Space):
else:
tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
temp_pindex = dobj.create_from_template(k_length_array.val, local_data=locdat,
dtype=locdat.dtype)
temp_pindex = dobj.create_from_template(
k_length_array.val, local_data=locdat, dtype=locdat.dtype)
nbin = len(tbb)
temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), minlength=nbin)
temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(),
minlength=nbin)
temp_rho = dobj.np_allreduce_sum(temp_rho)
assert not (temp_rho == 0).any(), "empty bins detected"
temp_k_lengths = np.bincount(dobj.local_data(temp_pindex).ravel(),
weights=dobj.local_data(k_length_array.val).ravel(), minlength=nbin)
weights=dobj.local_data(k_length_array.val).ravel(),
minlength=nbin)
temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho
temp_dvol = temp_rho*pdvol
self._powerIndexCache[key] = (binbounds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment