Commit 133bf484 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix interface of BlockDiagonalOperator

parent 3996cb92
......@@ -461,7 +461,7 @@ class Linearization(object):
if len(constants) == 0:
return Linearization.make_var(field, want_metric)
else:
ops = [ScalingOperator(0. if key in constants else 1., dom)
for key, dom in field.domain.items()]
bdop = BlockDiagonalOperator(field.domain, tuple(ops))
ops = {key: ScalingOperator(0. if key in constants else 1., dom)
for key, dom in field.domain.items()}
bdop = BlockDiagonalOperator(field.domain, ops)
return Linearization(field, bdop, want_metric=want_metric)
......@@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator):
"""
Parameters
----------
domain : MultiDomain
Domain and target of the operator.
operators : dict
Dictionary with operators domain names as keys and LinearOperators as
items.
Dictionary with subdomain names as keys and LinearOperators as items.
"""
def __init__(self, domain, operators):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain
self._ops = operators
self._ops = tuple(operators[key] for key in domain.keys())
self._capability = self._all_ops
for op in self._ops:
if op is not None:
......@@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op):
if self._domain != op._domain:
raise ValueError("domain mismatch")
res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops))
res = {key: v1(v2)
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain != op._domain:
raise ValueError("domain mismatch")
res = tuple(SumOperator.make([v1, v2], [selfneg, opneg])
for v1, v2 in zip(self._ops, op._ops))
res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
......@@ -363,7 +363,7 @@ def makeOp(input):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator(
input.domain, tuple(makeOp(val) for val in input.values()))
input.domain, {key: makeOp(val) for key, val in enumerate(input)})
raise NotImplementedError
......
......@@ -43,7 +43,8 @@ def test_dataconv():
def test_blockdiagonal():
op = ift.BlockDiagonalOperator(dom, (ift.ScalingOperator(20., dom["d1"]),))
op = ift.BlockDiagonalOperator(
dom, {"d1": ift.ScalingOperator(20., dom["d1"])})
op2 = op(op)
ift.extra.consistency_check(op2)
assert_equal(type(op2), ift.BlockDiagonalOperator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment