diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 9fd8efab3fbc08f1295bfa80812ee2b949b651a0..6a5a4cba4599d04bb05430fd508f109bc842ac98 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -338,12 +338,17 @@ class PartialExtractor(LinearOperator): if self._domain[key] is not self._target[key]: raise ValueError("domain mismatch") self._capability = self.TIMES | self.ADJOINT_TIMES + self._compldomain = MultiDomain.make({kk: self._domain[kk] + for kk in self._domain.keys() + if kk not in self._target.keys()}) def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: return x.extract(self._target) - return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) + res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()}) + res1 = MultiField.full(self._compldomain, 0.) + return res0.unite(res1) class MatrixProductOperator(EndomorphicOperator): diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index a6782a6c0d96d81498c0263724692af4de1873d3..08e3d8106970d1ea3359adbeae8b2375ed083543 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -295,3 +295,15 @@ def testValueInserter(sp, seed): ind.append(np.random.randint(0, ss-1)) op = ift.ValueInserter(sp, ind) ift.extra.consistency_check(op) + + +@pmp('seed', [12, 3]) +def testPartialExtractor(seed): + np.random.seed(seed) + tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)} + dom = tgt.copy() + dom['c'] = ift.RGSpace(3) + dom = ift.MultiDomain.make(dom) + tgt = ift.MultiDomain.make(tgt) + op = ift.PartialExtractor(dom, tgt) + ift.extra.consistency_check(op)