Commit c7e95442 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more MPI improvements

parent bad97bba
......@@ -59,6 +59,8 @@ class data_object(object):
distaxis = -1
self._distaxis = distaxis
self._data = data
if local_shape(self._shape, self._distaxis) != self._data.shape:
raise ValueError("shape mismatch")
# def _sanity_checks(self):
# # check whether the distaxis is consistent
......
......@@ -4,6 +4,7 @@ from ..domain_tuple import DomainTuple
from ..domains import PowerSpace, RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import dobj
class ExpTransform(LinearOperator):
......@@ -27,7 +28,7 @@ class ExpTransform(LinearOperator):
for i in range(ndim):
if isinstance(target, RGSpace):
rng = np.arange(target.shape[i])
tmp = np.minimum(rng, target.shape[i] + 1 - rng)
tmp = np.minimum(rng, target.shape[i]+1-rng)
k_array = tmp * target.distances[i]
else:
k_array = target.k_lengths
......@@ -42,8 +43,8 @@ class ExpTransform(LinearOperator):
# Save t_min for later
t_mins[i] = t_min
bindistances[i] = (t_max - t_min) / (dof[i] - 1)
coord = np.append(0., 1. + (log_k_array - t_min) / bindistances[i])
bindistances[i] = (t_max-t_min) / (dof[i]-1)
coord = np.append(0., 1. + (log_k_array-t_min) / bindistances[i])
self._bindex[i] = np.floor(coord).astype(int)
# Interpolated value is computed via
......@@ -52,7 +53,7 @@ class ExpTransform(LinearOperator):
self._frac[i] = coord - self._bindex[i]
from ..domains import LogRGSpace
log_space = LogRGSpace(2 * dof + 1, bindistances,
log_space = LogRGSpace(2*dof+1, bindistances,
t_mins, harmonic=False)
self._target = DomainTuple.make(target)
self._domain = DomainTuple.make(log_space)
......@@ -67,30 +68,34 @@ class ExpTransform(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.to_global_data()
x = x.val
ax = dobj.distaxis(x)
ndim = len(self.target.shape)
idx = ()
curshp = list(self._dom(mode).shape)
for d in range(ndim):
fst_dims = (1,) * d
lst_dims = (1,) * (ndim - d - 1)
wgt = self._frac[d].reshape(fst_dims + (-1,) + lst_dims)
idx = (slice(None,),) * d
wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
# ADJOINT_TIMES
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d],), x * (1. - wgt))
np.add.at(xnew, idx + (self._bindex[d] + 1,), x * wgt)
# TIMES
else:
xnew = x[idx + (self._bindex[d],)] * (1. - wgt)
xnew += x[idx + (self._bindex[d] + 1,)] * wgt
x = xnew
idx = (slice(None),) + idx
return Field.from_global_data(self._tgt(mode), x)
np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
if d == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
@property
def capability(self):
......
......@@ -35,18 +35,18 @@ class QHTOperator(LinearOperator):
x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n))
ax = dobj.distaxis(x)
globshape = x.shape
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
if i == dobj.distaxis(x):
x = dobj.redistribute(x, nodist=(i,))
ax = dobj.distaxis(x)
x = dobj.local_data(x)
x[sl] = hartley(x[sl], axes=(i,))
x = dobj.from_local_data(x.shape, x, distaxis=ax)
x = dobj.redistribute(x, dist=i)
else:
x[sl] = hartley(x[sl], axes=(i,))
if i == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
x[sl] = hartley(x[sl], axes=(i,))
x = dobj.from_local_data(globshape, x, distaxis=curax)
if i == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
@property
......
from ..domain_tuple import DomainTuple
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
from .. import dobj
class SymmetrizingOperator(EndomorphicOperator):
......@@ -14,12 +15,20 @@ class SymmetrizingOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
# FIXME Not efficient with MPI
tmp = x.to_global_data().copy()
tmp = x.copy().val
ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in range(self._ndim):
lead = (slice(None),)*i
if i == ax:
tmp = dobj.redistribute(tmp, nodist=(ax,))
curax = dobj.distaxis(tmp)
tmp = dobj.local_data(tmp)
tmp[lead + (slice(1, None),)] -= tmp[lead + (slice(None, 0, -1),)]
return Field.from_global_data(self.target, tmp)
tmp = dobj.from_local_data(globshape, tmp, distaxis=curax)
if i == ax:
tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp)
@property
def capability(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment