Commit 2fc8a9fc authored by Martin Reinecke's avatar Martin Reinecke

make use of synergies

parent 9b6d9324
Pipeline #22718 passed with stage
in 4 minutes and 45 seconds
......@@ -43,15 +43,18 @@ class DOFProjectionOperator(LinearOperator):
if (wgt == 0).any():
raise ValueError("empty bins detected")
self._init2(dofdex.val, space, DOFSpace(wgt))
def _init2(self, dofdex, space, other_space):
self._space = space
tgt = list(self._domain)
tgt[self._space] = DOFSpace(wgt)
tgt[self._space] = other_space
self._target = DomainTuple.make(tgt)
if dobj.default_distaxis() in self.domain.axes[self._space]:
dofdex = dobj.local_data(dofdex.val)
if dobj.default_distaxis() in self._domain.axes[self._space]:
dofdex = dobj.local_data(dofdex)
else: # dofdex must be available fully on every task
dofdex = dobj.to_global_data(dofdex.val)
dofdex = dobj.to_global_data(dofdex)
self._dofdex = dofdex.ravel()
firstaxis = self._domain.axes[self._space][0]
lastaxis = self._domain.axes[self._space][-1]
......
......@@ -16,17 +16,14 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, DomainTuple
from ..spaces import PowerSpace
from .linear_operator import LinearOperator
from .. import dobj
import numpy as np
from .dof_projection_operator import DOFProjectionOperator
from .. import Field, DomainTuple, dobj
from ..spaces import PowerSpace
class PowerProjectionOperator(LinearOperator):
class PowerProjectionOperator(DOFProjectionOperator):
def __init__(self, domain, power_space=None, space=None):
super(PowerProjectionOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
if space is None and len(self._domain) == 1:
......@@ -45,60 +42,4 @@ class PowerProjectionOperator(LinearOperator):
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._space = space
tgt = list(self._domain)
tgt[self._space] = power_space
self._target = DomainTuple.make(tgt)
pindex = self._target[self._space].pindex
if dobj.default_distaxis() in self.domain.axes[self._space]:
pindex = dobj.local_data(pindex)
else: # pindex must be available fully on every task
pindex = dobj.to_global_data(pindex)
self._pindex = pindex.ravel()
firstaxis = self._domain.axes[self._space][0]
lastaxis = self._domain.axes[self._space][-1]
arrshape = dobj.local_shape(self._domain.shape, 0)
presize = np.prod(arrshape[0:firstaxis], dtype=np.int)
postsize = np.prod(arrshape[lastaxis+1:], dtype=np.int)
self._hshape = (presize, self._target[self._space].shape[0], postsize)
self._pshape = (presize, self._pindex.size, postsize)
def _times(self, x):
arr = dobj.local_data(x.weight(1).val)
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._pindex, slice(None)), arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._target.shape)
res = Field(self._target, dobj.from_global_data(oarr))
else:
oarr = oarr.reshape(dobj.local_shape(self._target.shape,
dobj.distaxis(x.val)))
res = Field(self._target,
dobj.from_local_data(self._target.shape, oarr,
dobj.default_distaxis()))
return res.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
res = Field.empty(self._domain, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = dobj.to_global_data(x.val)
else:
arr = dobj.local_data(x.val)
arr = arr.reshape(self._hshape)
oarr = dobj.local_data(res.val).reshape(self._pshape)
oarr[()] = arr[(slice(None), self._pindex, slice(None))]
return res
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return False
self._init2(power_space.pindex, space, power_space)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment