Skip to content
Snippets Groups Projects
Commit 0327ffdc authored by Philipp Arras's avatar Philipp Arras
Browse files

Adapt interface of DomainTupleFieldInserter

parent 4f79d69c
No related branches found
No related tags found
No related merge requests found
...@@ -23,31 +23,38 @@ from .linear_operator import LinearOperator ...@@ -23,31 +23,38 @@ from .linear_operator import LinearOperator
class DomainTupleFieldInserter(LinearOperator): class DomainTupleFieldInserter(LinearOperator):
"""Writes the content of a :class:`Field` into one slice of a :class:`DomainTuple`. """Writes the content of a :class:`Field` into one slice of a
:class:`DomainTuple`.
Parameters Parameters
---------- ----------
domain : Domain, tuple of Domain or DomainTuple target : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple space : int
index : Integer The index of the sub-domain which is inserted.
Index at which new_space shall be added to domain. index : tuple
position : tuple Slice in new sub-domain in which the input field shall be written into.
Slice in new_space in which the input field shall be written into.
""" """
def __init__(self, domain, new_space, index, position):
self._domain = DomainTuple.make(domain) def __init__(self, target, space, pos):
tgt = list(self.domain) if not space <= len(target) or space < 0:
tgt.insert(index, new_space) raise ValueError
self._target = DomainTuple.make(tgt) self._target = DomainTuple.make(target)
dom = list(self.target)
dom.pop(space)
self._domain = DomainTuple.make(dom)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
fst_dims = sum(len(dd.shape) for dd in self.domain[:index])
new_space = target[space]
nshp = new_space.shape nshp = new_space.shape
if len(position) != len(nshp): fst_dims = sum(len(dd.shape) for dd in self.target[:space])
if len(pos) != len(nshp):
raise ValueError("shape mismatch between new_space and position") raise ValueError("shape mismatch between new_space and position")
for s, p in zip(nshp, position): for s, p in zip(nshp, pos):
if p < 0 or p >= s: if p < 0 or p >= s:
raise ValueError("bad position value") raise ValueError("bad position value")
self._slc = (slice(None),)*fst_dims + position
self._slc = (slice(None),)*fst_dims + pos
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
......
...@@ -189,11 +189,10 @@ def testContractionOperator(spaces, wgt, dtype): ...@@ -189,11 +189,10 @@ def testContractionOperator(spaces, wgt, dtype):
def testDomainTupleFieldInserter(): def testDomainTupleFieldInserter():
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12), target = ift.DomainTuple.make((ift.UnstructuredDomain([3, 2]),
ift.UnstructuredDomain(7),
ift.RGSpace([4, 22]))) ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7) op = ift.DomainTupleFieldInserter(target, 1, (5,))
pos = (5,)
op = ift.DomainTupleFieldInserter(domain, new_space, 0, pos)
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment