Commit 4212cf8f authored by Philipp Arras's avatar Philipp Arras
Browse files

Add DomainTupleFieldInserter

parent 63789fe6
......@@ -22,7 +22,7 @@ from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_distributor import DomainDistributor
from .operators.domain_tuple_operators import DomainDistributor, DomainTupleFieldInserter
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.harmonic_operators import (
......
......@@ -22,7 +22,7 @@ from ..compat import *
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..operators.distributors import PowerDistributor
from ..operators.domain_distributor import DomainDistributor
from ..operators.domain_tuple_operators import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import exp
......
......@@ -64,3 +64,30 @@ class DomainDistributor(LinearOperator):
else:
return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces])
class DomainTupleFieldInserter(LinearOperator):
def __init__(self, domain, new_space, ind):
'''Writes the content of a field into one slice of a DomainTuple.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
ind : Index of the same space as new_space
'''
# FIXME Add assertions
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(list(self.domain) + [new_space])
self._capability = self.TIMES | self.ADJOINT_TIMES
self._ind = ind
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = np.zeros(self.target.shape, dtype=x.dtype)
res[..., self._ind] = x.to_global_data()
return Field.from_global_data(self.target, res)
else:
return Field.from_global_data(self.domain,
x.to_global_data()[..., self._ind])
......@@ -194,6 +194,14 @@ class Consistency_Tests(unittest.TestCase):
op = ift.DomainDistributor(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
def testDomainTupleFieldInserter(self):
domain = ift.DomainTuple.make((ift.UnstructuredDomain(12),
ift.RGSpace([4, 22])))
new_space = ift.UnstructuredDomain(7)
ind = 5
op = ift.DomainTupleFieldInserter(domain, new_space, ind)
ift.extra.consistency_check(op)
@expand(product([0, 2], [np.float64, np.complex128]))
def testSymmetrizingOperator(self, space, dtype):
dom = (ift.LogRGSpace(10, [2.], [1.]), ift.UnstructuredDomain(13),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment