Commit 84c4d1e7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Synchronize code and documentation

parent d02d1d26
......@@ -36,7 +36,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
self._domain = domain
self._ops = tuple(operators[key] for key in domain.keys())
self._ops = tuple(operators[key] if key in operators else None for key in domain.keys())
self._capability = self._all_ops
for op in self._ops:
if op is not None:
......@@ -47,6 +47,10 @@ class BlockDiagonalOperator(EndomorphicOperator):
else:
raise TypeError("LinearOperator expected")
def get_sqrt(self):
ops = {kk: vv.sqrt() for kk, vv in self._ops.items() if vv is not None}
return BlockDiagonalOperator(self._domain, ops)
def apply(self, x, mode):
self._check_input(x, mode)
val = tuple(op.apply(v, mode=mode) if op is not None else v
......
......@@ -71,3 +71,13 @@ def test_blockdiagonal():
f1 = op2(ift.full(dom, 1))
for val in f1.values():
assert_equal((val == 40).s_all(), True)
def test_blockdiagonal_nontrivial():
dom = ift.makeDomain({"d1": ift.RGSpace(10), "d2": ift.UnstructuredDomain(2)})
op = ift.BlockDiagonalOperator(dom, {"d1": ift.ScalingOperator(dom["d1"], 2)})
ift.extra.check_linear_operator(op)
assert op.domain == dom
fld = ift.from_random(dom)
ift.extra.assert_equal(op(fld)["d1"], 2*fld["d1"])
ift.extra.assert_equal(op(fld)["d2"], fld["d2"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment