Commit 133bf484 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix interface of BlockDiagonalOperator

parent 3996cb92
...@@ -461,7 +461,7 @@ class Linearization(object): ...@@ -461,7 +461,7 @@ class Linearization(object):
if len(constants) == 0: if len(constants) == 0:
return Linearization.make_var(field, want_metric) return Linearization.make_var(field, want_metric)
else: else:
ops = [ScalingOperator(0. if key in constants else 1., dom) ops = {key: ScalingOperator(0. if key in constants else 1., dom)
for key, dom in field.domain.items()] for key, dom in field.domain.items()}
bdop = BlockDiagonalOperator(field.domain, tuple(ops)) bdop = BlockDiagonalOperator(field.domain, ops)
return Linearization(field, bdop, want_metric=want_metric) return Linearization(field, bdop, want_metric=want_metric)
...@@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator):
""" """
Parameters Parameters
---------- ----------
domain : MultiDomain
Domain and target of the operator.
operators : dict operators : dict
Dictionary with operators domain names as keys and LinearOperators as Dictionary with subdomain names as keys and LinearOperators as items.
items.
""" """
def __init__(self, domain, operators): def __init__(self, domain, operators):
if not isinstance(domain, MultiDomain): if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected") raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain self._domain = domain
self._ops = operators self._ops = tuple(operators[key] for key in domain.keys())
self._capability = self._all_ops self._capability = self._all_ops
for op in self._ops: for op in self._ops:
if op is not None: if op is not None:
...@@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op): def _combine_chain(self, op):
if self._domain != op._domain: if self._domain != op._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops)) res = {key: v1(v2)
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res) return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg): def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator from ..operators.sum_operator import SumOperator
if self._domain != op._domain: if self._domain != op._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
res = tuple(SumOperator.make([v1, v2], [selfneg, opneg]) res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for v1, v2 in zip(self._ops, op._ops)) for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res) return BlockDiagonalOperator(self._domain, res)
...@@ -363,7 +363,7 @@ def makeOp(input): ...@@ -363,7 +363,7 @@ def makeOp(input):
return DiagonalOperator(input) return DiagonalOperator(input)
if isinstance(input, MultiField): if isinstance(input, MultiField):
return BlockDiagonalOperator( return BlockDiagonalOperator(
input.domain, tuple(makeOp(val) for val in input.values())) input.domain, {key: makeOp(val) for key, val in enumerate(input)})
raise NotImplementedError raise NotImplementedError
......
...@@ -43,7 +43,8 @@ def test_dataconv(): ...@@ -43,7 +43,8 @@ def test_dataconv():
def test_blockdiagonal(): def test_blockdiagonal():
op = ift.BlockDiagonalOperator(dom, (ift.ScalingOperator(20., dom["d1"]),)) op = ift.BlockDiagonalOperator(
dom, {"d1": ift.ScalingOperator(20., dom["d1"])})
op2 = op(op) op2 = op(op)
ift.extra.consistency_check(op2) ift.extra.consistency_check(op2)
assert_equal(type(op2), ift.BlockDiagonalOperator) assert_equal(type(op2), ift.BlockDiagonalOperator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment