Commit 412b149a authored by Martin Reinecke

simplify data redistribution
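
This commit replaces the redistribute/local_data boilerplate that was
repeated in every operator's apply() with two helpers provided by both
data-object backends: ensure_not_distributed(arr, axes) returns a data
object whose given axes are process-local together with its local data,
and ensure_default_distributed(arr) restores the default distribution
along axis 0. A minimal runnable sketch of the resulting call-site idiom,
using the trivial serial-backend helpers from this commit (the sample
array and the in-place update are illustrative placeholders, not part of
the commit):

```python
import numpy as np

# Serial-backend helpers exactly as added in this commit (numpy backend):
# the data object and its local data coincide, so both are no-ops.
def ensure_not_distributed(arr, axes):
    return arr, arr

def ensure_default_distributed(arr):
    return arr

# The call-site idiom the operators now use:
v = np.arange(6.).reshape(2, 3)
v, loc = ensure_not_distributed(v, (1,))  # make axis 1 process-local
loc[:, 1:] -= loc[:, :-1]                 # operate on the local data
v = ensure_default_distributed(v)         # back to default distribution
print(v)
```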

parent 0e487953
@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
            "np_allreduce_min", "np_allreduce_max",
            "distaxis", "from_local_data", "from_global_data", "to_global_data",
            "redistribute", "default_distaxis", "is_numpy",
-           "lock", "locked", "uniform_full", "transpose", "to_global_data_rw"]
+           "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
+           "ensure_not_distributed", "ensure_default_distributed"]
 
 _comm = MPI.COMM_WORLD
 ntask = _comm.Get_size()
@@ -540,3 +541,15 @@ def lock(arr):
 
 def locked(arr):
     return not arr._data.flags.writeable
+
+
+def ensure_not_distributed(arr, axes):
+    if arr._distaxis in axes:
+        arr = redistribute(arr, nodist=axes)
+    return arr, arr._data
+
+
+def ensure_default_distributed(arr):
+    if arr._distaxis != 0:
+        arr = redistribute(arr, dist=0)
+    return arr
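
Under MPI, ensure_not_distributed redistributes only when the currently
distributed axis is among the requested ones, then hands back both the
(possibly redistributed) data object and its process-local array. The toy
model below (a hypothetical _Toy class and a simplified redistribute, for
illustration only; the real versions are in the MPI backend hunk above)
shows that control flow without needing mpi4py:

```python
import numpy as np

# Toy stand-ins for the distributed data object and redistribute();
# they model only the distribution-axis bookkeeping, not actual MPI.
class _Toy:
    def __init__(self, data, distaxis):
        self._data, self._distaxis = data, distaxis

def redistribute(arr, dist=None, nodist=None):
    if dist is not None:                  # move to the requested axis
        return _Toy(arr._data, dist)
    for ax in range(arr._data.ndim):      # first axis not in `nodist`
        if ax not in nodist:
            return _Toy(arr._data, ax)

# Same logic as the MPI helper added above.
def ensure_not_distributed(arr, axes):
    if arr._distaxis in axes:
        arr = redistribute(arr, nodist=axes)
    return arr, arr._data

arr = _Toy(np.zeros((4, 5)), distaxis=1)
arr, loc = ensure_not_distributed(arr, (1,))  # axis 1 becomes local
print(arr._distaxis)                          # -> 0
```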
@@ -32,7 +32,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
            "np_allreduce_min", "np_allreduce_max",
            "distaxis", "from_local_data", "from_global_data", "to_global_data",
            "redistribute", "default_distaxis", "is_numpy",
-           "lock", "locked", "uniform_full", "to_global_data_rw"]
+           "lock", "locked", "uniform_full", "to_global_data_rw",
+           "ensure_not_distributed", "ensure_default_distributed"]
 
 ntask = 1
 rank = 0
@@ -132,3 +133,11 @@ def locked(arr):
 
 def uniform_full(shape, fill_value, dtype=None, distaxis=-1):
     return np.broadcast_to(fill_value, shape)
+
+
+def ensure_not_distributed(arr, axes):
+    return arr, arr
+
+
+def ensure_default_distributed(arr):
+    return arr
@@ -93,14 +93,10 @@ class CentralZeroPadder(LinearOperator):
 
     def apply(self, x, mode):
         self._check_input(x, mode)
-        x = x.val
-        dax = dobj.distaxis(x)
+        v = x.val
         shp_out = self._tgt(mode).shape
-        axes = self._target.axes[self._space]
-        if dax in axes:
-            x = dobj.redistribute(x, nodist=axes)
-        curax = dobj.distaxis(x)
-        x = dobj.local_data(x)
+        v, x = dobj.ensure_not_distributed(v, self._target.axes[self._space])
+        curax = dobj.distaxis(v)
 
         if mode == self.TIMES:
             # slice along each axis and copy the data to an
@@ -114,7 +110,5 @@ class CentralZeroPadder(LinearOperator):
             y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
             for i in self.slicer:
                 y[i] = x[i]
-            y = dobj.from_local_data(shp_out, y, distaxis=curax)
-            if dax in axes:
-                y = dobj.redistribute(y, dist=dax)
-        return Field(self._tgt(mode), val=y)
+            v = dobj.from_local_data(shp_out, y, distaxis=dobj.distaxis(v))
+        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -105,8 +105,7 @@ class ExpTransform(LinearOperator):
 
     def apply(self, x, mode):
         self._check_input(x, mode)
-        x = x.val
-        ax = dobj.distaxis(x)
+        v = x.val
         ndim = len(self.target.shape)
         curshp = list(self._dom(mode).shape)
         d0 = self._target.axes[self._space][0]
@@ -114,10 +113,7 @@ class ExpTransform(LinearOperator):
             idx = (slice(None),) * d
             wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
 
-            if d == ax:
-                x = dobj.redistribute(x, nodist=(ax,))
-            curax = dobj.distaxis(x)
-            x = dobj.local_data(x)
+            v, x = dobj.ensure_not_distributed(v, (d,))
 
             if mode == self.ADJOINT_TIMES:
                 shp = list(x.shape)
@@ -130,7 +126,5 @@ class ExpTransform(LinearOperator):
                 xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
 
             curshp[d] = self._tgt(mode).shape[d]
-            x = dobj.from_local_data(curshp, xnew, distaxis=curax)
-            if d == ax:
-                x = dobj.redistribute(x, dist=ax)
-        return Field(self._tgt(mode), val=x)
+            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
+        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -322,23 +322,19 @@ class SHTOperator(LinearOperator):
 
     def _apply_spherical(self, x, mode):
         axes = x.domain.axes[self._space]
         axis = axes[0]
-        tval = x.val
-        if dobj.distaxis(tval) == axis:
-            tval = dobj.redistribute(tval, nodist=(axis,))
-        distaxis = dobj.distaxis(tval)
+        v = x.val
+        v, idat = dobj.ensure_not_distributed(v, (axis,))
+        distaxis = dobj.distaxis(v)
 
         p2h = not x.domain[self._space].harmonic
         tdom = self._tgt(mode)
         func = self._slice_p2h if p2h else self._slice_h2p
-        idat = dobj.local_data(tval)
         odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
                         dtype=x.dtype)
         for slice in utilities.get_slice_list(idat.shape, axes):
             odat[slice] = func(idat[slice])
         odat = dobj.from_local_data(tdom.shape, odat, distaxis)
-        if distaxis != dobj.distaxis(x.val):
-            odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
-        return Field(tdom, odat)
+        return Field(tdom, dobj.ensure_default_distributed(odat))
 
 class HarmonicTransformOperator(LinearOperator):
@@ -78,10 +78,7 @@ class LaplaceOperator(EndomorphicOperator):
         sl_r = prefix + (slice(1, None),)  # "right" slice
         dpos = self._dpos.reshape((1,)*axis + (nval-1,))
         dposc = self._dposc.reshape((1,)*axis + (nval,))
-        locval = x.val
-        if axis == dobj.distaxis(locval):
-            locval = dobj.redistribute(locval, nodist=(axis,))
-        val = dobj.local_data(locval)
+        v, val = dobj.ensure_not_distributed(x.val, (axis,))
        ret = np.empty_like(val)
        if mode == self.TIMES:
            deriv = (val[sl_r]-val[sl_l])/dpos  # defined between points
@@ -99,7 +96,5 @@ class LaplaceOperator(EndomorphicOperator):
             ret[sl_l] = deriv
             ret[prefix + (-1,)] = 0.
             ret[sl_r] -= deriv
-        ret = dobj.from_local_data(locval.shape, ret, dobj.distaxis(locval))
-        if dobj.distaxis(locval) != dobj.distaxis(x.val):
-            ret = dobj.redistribute(ret, dist=dobj.distaxis(x.val))
-        return Field(self.domain, val=ret)
+        ret = dobj.from_local_data(x.shape, ret, dobj.distaxis(v))
+        return Field(self.domain, dobj.ensure_default_distributed(ret))
@@ -64,16 +64,11 @@ class QHTOperator(LinearOperator):
     def apply(self, x, mode):
         self._check_input(x, mode)
         dom = self._domain[self._space]
-        x = x.val * dom.scalar_dvol
+        v = x.val * dom.scalar_dvol
         n = self._domain.axes[self._space]
         rng = n if mode == self.TIMES else reversed(n)
-        ax = dobj.distaxis(x)
         for i in rng:
             sl = (slice(None),)*i + (slice(1, None),)
-            if i == ax:
-                x = dobj.redistribute(x, nodist=(ax,))
-            tmp = dobj.local_data(x)
+            v, tmp = dobj.ensure_not_distributed(v, (i,))
             tmp[sl] = hartley(tmp[sl], axes=(i,))
-            if i == ax:
-                x = dobj.redistribute(x, dist=ax)
-        return Field(self._tgt(mode), val=x)
+        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -58,13 +58,12 @@ class RegriddingOperator(LinearOperator):
         self._frac = [None] * ndim
         for d in range(ndim):
             tmp = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
-            self._bindex[d] = np.minimum(dom.shape[d]-2,tmp.astype(np.int))
+            self._bindex[d] = np.minimum(dom.shape[d]-2, tmp.astype(np.int))
             self._frac[d] = tmp-self._bindex[d]
 
     def apply(self, x, mode):
         self._check_input(x, mode)
-        x = x.val
-        ax = dobj.distaxis(x)
+        v = x.val
         ndim = len(self.target.shape)
         curshp = list(self._dom(mode).shape)
         d0 = self._target.axes[self._space][0]
@@ -72,10 +71,7 @@ class RegriddingOperator(LinearOperator):
             idx = (slice(None),) * d
             wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
 
-            if d == ax:
-                x = dobj.redistribute(x, nodist=(ax,))
-            curax = dobj.distaxis(x)
-            x = dobj.local_data(x)
+            v, x = dobj.ensure_not_distributed(v, (d,))
 
             if mode == self.ADJOINT_TIMES:
                 shp = list(x.shape)
@@ -88,7 +84,5 @@ class RegriddingOperator(LinearOperator):
                 xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
 
             curshp[d] = self._tgt(mode).shape[d]
-            x = dobj.from_local_data(curshp, xnew, distaxis=curax)
-            if d == ax:
-                x = dobj.redistribute(x, dist=ax)
-        return Field(self._tgt(mode), val=x)
+            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
+        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -37,14 +37,9 @@ class SymmetrizingOperator(EndomorphicOperator):
 
     def apply(self, x, mode):
         self._check_input(x, mode)
-        tmp = x.val.copy()
-        ax = dobj.distaxis(tmp)
+        v = x.val.copy()
         for i in self._domain.axes[self._space]:
             lead = (slice(None),)*i
-            if i == ax:
-                tmp = dobj.redistribute(tmp, nodist=(ax,))
-            tmp2 = dobj.local_data(tmp)
-            tmp2[lead+(slice(1, None),)] -= tmp2[lead+(slice(None, 0, -1),)]
-            if i == ax:
-                tmp = dobj.redistribute(tmp, dist=ax)
-        return Field(self.target, val=tmp)
+            v, loc = dobj.ensure_not_distributed(v, (i,))
+            loc[lead+(slice(1, None),)] -= loc[lead+(slice(None, 0, -1),)]
+        return Field(self.target, dobj.ensure_default_distributed(v))