Commit c25acc23 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add infront functionality, very preliminary

parent 4212cf8f
......@@ -67,18 +67,26 @@ class DomainDistributor(LinearOperator):
class DomainTupleFieldInserter(LinearOperator):
def __init__(self, domain, new_space, ind):
def __init__(self, domain, new_space, ind, infront=False):
'''Writes the content of a field into one slice of a DomainTuple.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
ind : Index of the same space as new_space
ind : Integer
Index of the same space as new_space
infront : Boolean
If true, the new domain is added in the beginning of the
DomainTuple. Otherwise it is added at the end.
'''
# FIXME Add assertions
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(list(self.domain) + [new_space])
if infront:
self._target = DomainTuple.make([new_space] + list(self.domain))
else:
self._target = DomainTuple.make(list(self.domain) + [new_space])
self._infront = infront
self._capability = self.TIMES | self.ADJOINT_TIMES
self._ind = ind
......@@ -86,8 +94,16 @@ class DomainTupleFieldInserter(LinearOperator):
self._check_input(x, mode)
if mode == self.TIMES:
res = np.zeros(self.target.shape, dtype=x.dtype)
res[..., self._ind] = x.to_global_data()
if self._infront:
res[self._ind] = x.to_global_data()
else:
res[..., self._ind] = x.to_global_data()
return Field.from_global_data(self.target, res)
else:
return Field.from_global_data(self.domain,
x.to_global_data()[..., self._ind])
if self._infront:
return Field.from_global_data(self.domain,
x.to_global_data()[self._ind])
else:
return Field.from_global_data(
self.domain,
x.to_global_data()[..., self._ind])
......@@ -194,12 +194,13 @@ class Consistency_Tests(unittest.TestCase):
op = ift.DomainDistributor(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
def testDomainTupleFieldInserter(self):
@expand(product([True, False]))
def testDomainTupleFieldInserter(self, infront):
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12),
ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7)
ind = 5
op = ift.DomainTupleFieldInserter(domain, new_space, ind)
op = ift.DomainTupleFieldInserter(domain, new_space, ind, infront)
ift.extra.consistency_check(op)
@expand(product([0, 2], [np.float64, np.complex128]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment