Skip to content
Snippets Groups Projects
Commit 93b3fdd9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix PartialExtractor

parent 441e854e
Branches
No related tags found
1 merge request!368Add more automatic checks for operators
...@@ -338,12 +338,17 @@ class PartialExtractor(LinearOperator): ...@@ -338,12 +338,17 @@ class PartialExtractor(LinearOperator):
if self._domain[key] is not self._target[key]: if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
if kk not in self._target.keys()})
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
return x.extract(self._target) return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
class MatrixProductOperator(EndomorphicOperator): class MatrixProductOperator(EndomorphicOperator):
......
...@@ -295,3 +295,15 @@ def testValueInserter(sp, seed): ...@@ -295,3 +295,15 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1)) ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind) op = ift.ValueInserter(sp, ind)
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
np.random.seed(seed)
tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
dom = tgt.copy()
dom['c'] = ift.RGSpace(3)
dom = ift.MultiDomain.make(dom)
tgt = ift.MultiDomain.make(tgt)
op = ift.PartialExtractor(dom, tgt)
ift.extra.consistency_check(op)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment