Commit 0327ffdc authored by Philipp Arras's avatar Philipp Arras
Browse files

Adapt interface of DomainTupleFieldInserter

parent 4f79d69c
......@@ -23,31 +23,38 @@ from .linear_operator import LinearOperator
class DomainTupleFieldInserter(LinearOperator):
"""Writes the content of a :class:`Field` into one slice of a :class:`DomainTuple`.
"""Writes the content of a :class:`Field` into one slice of a
:class:`DomainTuple`.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
index : Integer
Index at which new_space shall be added to domain.
position : tuple
Slice in new_space in which the input field shall be written into.
target : Domain, tuple of Domain or DomainTuple
space : int
The index of the sub-domain which is inserted.
index : tuple
Slice in new sub-domain in which the input field shall be written into.
"""
def __init__(self, domain, new_space, index, position):
self._domain = DomainTuple.make(domain)
tgt = list(self.domain)
tgt.insert(index, new_space)
self._target = DomainTuple.make(tgt)
def __init__(self, target, space, pos):
if not space <= len(target) or space < 0:
raise ValueError
self._target = DomainTuple.make(target)
dom = list(self.target)
dom.pop(space)
self._domain = DomainTuple.make(dom)
self._capability = self.TIMES | self.ADJOINT_TIMES
fst_dims = sum(len(dd.shape) for dd in self.domain[:index])
new_space = target[space]
nshp = new_space.shape
if len(position) != len(nshp):
fst_dims = sum(len(dd.shape) for dd in self.target[:space])
if len(pos) != len(nshp):
raise ValueError("shape mismatch between new_space and position")
for s, p in zip(nshp, position):
for s, p in zip(nshp, pos):
if p < 0 or p >= s:
raise ValueError("bad position value")
self._slc = (slice(None),)*fst_dims + position
self._slc = (slice(None),)*fst_dims + pos
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -189,11 +189,10 @@ def testContractionOperator(spaces, wgt, dtype):
def testDomainTupleFieldInserter():
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12),
target = ift.DomainTuple.make((ift.UnstructuredDomain([3, 2]),
ift.UnstructuredDomain(7),
ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7)
pos = (5,)
op = ift.DomainTupleFieldInserter(domain, new_space, 0, pos)
op = ift.DomainTupleFieldInserter(target, 1, (5,))
ift.extra.consistency_check(op)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment