Commit 93b3fdd9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix PartialExtractor

parent 441e854e
......@@ -338,12 +338,17 @@ class PartialExtractor(LinearOperator):
if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
if kk not in self._target.keys()})
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
class MatrixProductOperator(EndomorphicOperator):
......@@ -295,3 +295,15 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
dom = tgt.copy()
dom['c'] = ift.RGSpace(3)
dom = ift.MultiDomain.make(dom)
tgt = ift.MultiDomain.make(tgt)
op = ift.PartialExtractor(dom, tgt)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment