Commit bad97bba authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve MPI behaviour

parent ee930b45
......@@ -35,13 +35,19 @@ class QHTOperator(LinearOperator):
x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n))
# MR FIXME: this needs to be fixed properly for MPI
x = dobj.to_global_data(x)
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
x[sl] = hartley(x[sl], axes=(i,))
return Field.from_global_data(self._tgt(mode), x)
if i == dobj.distaxis(x):
x = dobj.redistribute(x, nodist=(i,))
ax = dobj.distaxis(x)
x = dobj.local_data(x)
x[sl] = hartley(x[sl], axes=(i,))
x = dobj.from_local_data(x.shape, x, distaxis=ax)
x = dobj.redistribute(x, dist=i)
else:
x[sl] = hartley(x[sl], axes=(i,))
return Field(self._tgt(mode), val=x)
@property
def capability(self):
......
......@@ -29,10 +29,6 @@ class SlopeOperator(LinearOperator):
lst_dims = (1,) * (self.ndim - i - 1)
self.pos[i] += tmp.reshape(fst_dims + (shape[i],) + lst_dims)
@property
def sigmas(self):
return self._sigmas
@property
def domain(self):
return self._domain
......@@ -47,17 +43,17 @@ class SlopeOperator(LinearOperator):
# Times
if mode == self.TIMES:
inp = x.to_global_data()
res = self.sigmas[-1] * inp[-1]
res = self._sigmas[-1] * inp[-1]
for i in range(self.ndim):
res += self.sigmas[i] * inp[i] * self.pos[i]
res += self._sigmas[i] * inp[i] * self.pos[i]
return Field.from_global_data(self.target, res)
# Adjoint times
res = np.zeros(self.domain[0].shape)
xglob = x.to_global_data()
res[-1] = np.sum(xglob) * self.sigmas[-1]
res[-1] = np.sum(xglob) * self._sigmas[-1]
for i in range(self.ndim):
res[i] = np.sum(self.pos[i] * xglob) * self.sigmas[i]
res[i] = np.sum(self.pos[i] * xglob) * self._sigmas[i]
return Field.from_global_data(self.domain, res)
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment