Commit f4b1c9d7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactor DomainTupleFieldInserter

parent f8d22093
......@@ -27,43 +27,32 @@ from .linear_operator import LinearOperator
class DomainTupleFieldInserter(LinearOperator):
def __init__(self, domain, new_space, ind, infront=False):
def __init__(self, domain, new_space, index, position):
'''Writes the content of a field into one slice of a DomainTuple.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
ind : Integer
Index of the same space as new_space
infront : Boolean
If true, the new domain is added in the beginning of the
DomainTuple. Otherwise it is added at the end.
index : Integer
Position at which new_space shall be added to domain.
position : tuple
Slice in new_space at which the field shall be inserted.
'''
# FIXME Add assertions
self._domain = DomainTuple.make(domain)
if infront:
self._target = DomainTuple.make([new_space] + list(self.domain))
else:
self._target = DomainTuple.make(list(self.domain) + [new_space])
self._infront = infront
tgt = list(self.domain)
tgt.insert(index, new_space)
self._target = DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._ind = ind
self._slc = (slice(None),)*index + position + (slice(None),)*(
len(self.domain.shape) - index)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = np.zeros(self.target.shape, dtype=x.dtype)
if self._infront:
res[self._ind] = x.to_global_data()
else:
res[..., self._ind] = x.to_global_data()
res[self._slc] = x.to_global_data()
return Field.from_global_data(self.target, res)
else:
if self._infront:
return Field.from_global_data(self.domain,
x.to_global_data()[self._ind])
else:
return Field.from_global_data(
self.domain,
x.to_global_data()[..., self._ind])
return Field.from_global_data(self.domain,
x.to_global_data()[self._slc])
......@@ -194,13 +194,12 @@ class Consistency_Tests(unittest.TestCase):
op = ift.ContractionOperator(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([True, False]))
def testDomainTupleFieldInserter(self, infront):
def testDomainTupleFieldInserter(self):
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12),
ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7)
ind = 5
op = ift.DomainTupleFieldInserter(domain, new_space, ind, infront)
pos = (5,)
op = ift.DomainTupleFieldInserter(domain, new_space, 0, pos)
ift.extra.consistency_check(op)
@expand(product([0, 2], [np.float64, np.complex128]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment