Commit 0db9e4a6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more fixes

parent 3f5859d0
Pipeline #21362 failed with stage
in 3 minutes and 54 seconds
......@@ -50,12 +50,18 @@ class PowerProjectionOperator(LinearOperator):
tgt[self._space] = power_space
self._target = DomainTuple.make(tgt)
# shopping list:
# 1) make sure that pindex is distributed in the same way as in the Field living on self.domain.
# 2) if the operated-on space is not distributed (i.e. if it is not space 0), _no_ further communication is necessary
def _times(self, x):
# harmonic field goes in
# pindex must be distributed in the same way as harmonic field
# power field must be available in full
pindex = self._target[self._space].pindex
res = Field.zeros(self._target, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
else:
else: # pindex must be available fully on every task
pindex = dobj.to_global_data(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
......@@ -64,8 +70,15 @@ class PowerProjectionOperator(LinearOperator):
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
oarr = np.zeros((presize,self._target[self._space].shape[0],postsize), dtype=x.dtype)
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr)
oarr = oarr.reshape(self._target.shape)
res = Field(self._target, dobj.from_global_data(oarr))
else:
oarr = oarr.reshape(dobj.get_locshape(self._target.shape, dobj.distaxis(x.val)))
res = Field(self._target, dobj.from_local_data(self._target.shape,oarr,dobj.default_distaxis()))
return res.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
......@@ -73,10 +86,11 @@ class PowerProjectionOperator(LinearOperator):
res = Field.empty(self._domain, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
arr = dobj.to_global_data(x.val)
else:
pindex = dobj.to_global_data(pindex)
arr = dobj.local_data(x.val)
pindex = pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment