Commit 2ac52b09 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 0e747fa3
......@@ -28,6 +28,7 @@ from ..sugar import full, makeDomain
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .domain_tuple_field_inserter import DomainTupleFieldInserter
from .. import utilities
class VdotOperator(LinearOperator):
......@@ -46,13 +47,15 @@ class VdotOperator(LinearOperator):
class SumReductionOperator(LinearOperator):
def __init__(self, domain, spaces=None):
self._spaces = spaces
self._domain = domain
if spaces is None:
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) == len(self._domain):
self._spaces = None
if self._spaces is None:
self._target = DomainTuple.scalar_domain()
else:
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i == spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i == spaces)))
self._target = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if not(i in self._spaces)))
self._marg_space = makeDomain(tuple(dom for i, dom in enumerate(self._domain) if (i in self._spaces)))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -65,11 +68,7 @@ class SumReductionOperator(LinearOperator):
if self._spaces is None:
return full(self._domain, x.local_data[()])
else:
if isinstance(self._spaces, int):
sp = (self._spaces, )
else:
sp = self._spaces
for i in sp:
for i in self._spaces:
ns = self._domain._dom[i]
ps = tuple(i - 1 for i in ns.shape)
dtfi = DomainTupleFieldInserter(domain=self._target, new_space=ns, index=i, position=ps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment