Commit bad97bba authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve MPI behaviour

parent ee930b45
...@@ -35,13 +35,19 @@ class QHTOperator(LinearOperator): ...@@ -35,13 +35,19 @@ class QHTOperator(LinearOperator):
x = x.val * self.domain[0].scalar_dvol() x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape) n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n)) rng = range(n) if mode == self.TIMES else reversed(range(n))
# MR FIXME: this needs to be fixed properly for MPI
x = dobj.to_global_data(x)
for i in rng: for i in rng:
sl = (slice(None),)*i + (slice(1, None),) sl = (slice(None),)*i + (slice(1, None),)
if i == dobj.distaxis(x):
x = dobj.redistribute(x, nodist=(i,))
ax = dobj.distaxis(x)
x = dobj.local_data(x)
x[sl] = hartley(x[sl], axes=(i,))
x = dobj.from_local_data(x.shape, x, distaxis=ax)
x = dobj.redistribute(x, dist=i)
else:
x[sl] = hartley(x[sl], axes=(i,)) x[sl] = hartley(x[sl], axes=(i,))
return Field.from_global_data(self._tgt(mode), x) return Field(self._tgt(mode), val=x)
@property @property
def capability(self): def capability(self):
......
...@@ -29,10 +29,6 @@ class SlopeOperator(LinearOperator): ...@@ -29,10 +29,6 @@ class SlopeOperator(LinearOperator):
lst_dims = (1,) * (self.ndim - i - 1) lst_dims = (1,) * (self.ndim - i - 1)
self.pos[i] += tmp.reshape(fst_dims + (shape[i],) + lst_dims) self.pos[i] += tmp.reshape(fst_dims + (shape[i],) + lst_dims)
@property
def sigmas(self):
return self._sigmas
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
...@@ -47,17 +43,17 @@ class SlopeOperator(LinearOperator): ...@@ -47,17 +43,17 @@ class SlopeOperator(LinearOperator):
# Times # Times
if mode == self.TIMES: if mode == self.TIMES:
inp = x.to_global_data() inp = x.to_global_data()
res = self.sigmas[-1] * inp[-1] res = self._sigmas[-1] * inp[-1]
for i in range(self.ndim): for i in range(self.ndim):
res += self.sigmas[i] * inp[i] * self.pos[i] res += self._sigmas[i] * inp[i] * self.pos[i]
return Field.from_global_data(self.target, res) return Field.from_global_data(self.target, res)
# Adjoint times # Adjoint times
res = np.zeros(self.domain[0].shape) res = np.zeros(self.domain[0].shape)
xglob = x.to_global_data() xglob = x.to_global_data()
res[-1] = np.sum(xglob) * self.sigmas[-1] res[-1] = np.sum(xglob) * self._sigmas[-1]
for i in range(self.ndim): for i in range(self.ndim):
res[i] = np.sum(self.pos[i] * xglob) * self.sigmas[i] res[i] = np.sum(self.pos[i] * xglob) * self._sigmas[i]
return Field.from_global_data(self.domain, res) return Field.from_global_data(self.domain, res)
@property @property
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment