From 93b3fdd91e33c79d65da13f91919b668369c1fa6 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Wed, 6 Nov 2019 12:04:36 +0100 Subject: [PATCH] Fix PartialExtractor --- nifty5/operators/simple_linear_operators.py | 7 ++++++- test/test_operators/test_adjoint.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 9fd8efab3..6a5a4cba4 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -338,12 +338,17 @@ class PartialExtractor(LinearOperator): if self._domain[key] is not self._target[key]: raise ValueError("domain mismatch") self._capability = self.TIMES | self.ADJOINT_TIMES + self._compldomain = MultiDomain.make({kk: self._domain[kk] + for kk in self._domain.keys() + if kk not in self._target.keys()}) def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: return x.extract(self._target) - return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) + res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()}) + res1 = MultiField.full(self._compldomain, 0.) + return res0.unite(res1) class MatrixProductOperator(EndomorphicOperator): diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index a6782a6c0..08e3d8106 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -295,3 +295,15 @@ def testValueInserter(sp, seed): ind.append(np.random.randint(0, ss-1)) op = ift.ValueInserter(sp, ind) ift.extra.consistency_check(op) + + +@pmp('seed', [12, 3]) +def testPartialExtractor(seed): + np.random.seed(seed) + tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)} + dom = tgt.copy() + dom['c'] = ift.RGSpace(3) + dom = ift.MultiDomain.make(dom) + tgt = ift.MultiDomain.make(tgt) + op = ift.PartialExtractor(dom, tgt) + ift.extra.consistency_check(op) -- GitLab