Commit 835fc086 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add docu and cosmetics

parent 6774b35a
......@@ -28,11 +28,26 @@ from .linear_operator import LinearOperator
class DomainDistributor(LinearOperator):
"""A linear operator which broadcasts a field to a larger domain.
This DomainDistributor broadcasts a field which is defined on a
DomainTuple to a DomainTuple which contains the former as a subset. The
entries of the field are copied such that they are constant in the
direction of the new spaces.
Parameters
----------
target : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "target" which are taken as domain.
"""
def __init__(self, target, spaces):
self._target = DomainTuple.make(target)
self._spaces = utilities.parse_spaces(spaces, len(self._target))
self._domain = [tgt for i, tgt in enumerate(self._target)
if i in self._spaces]
self._domain = [
tgt for i, tgt in enumerate(self._target) if i in self._spaces
]
self._domain = DomainTuple.make(self._domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
......@@ -43,9 +58,9 @@ class DomainDistributor(LinearOperator):
shp = []
for i, tgt in enumerate(self._target):
tmp = tgt.shape if i > 0 else tgt.local_shape
shp += tmp if i in self._spaces else(1,)*len(tgt.shape)
shp += tmp if i in self._spaces else (1,)*len(tgt.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape)
return Field.from_local_data(self._target, ldat)
else:
return x.sum([s for s in range(len(x.domain))
if s not in self._spaces])
return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment