diff --git a/nifty5/linearization.py b/nifty5/linearization.py index babcd420d2609e5037b069a738735bdd77cb9f91..1affc0b65134603bb47de2a69428310d33acb2c7 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -461,7 +461,7 @@ class Linearization(object): if len(constants) == 0: return Linearization.make_var(field, want_metric) else: - ops = [ScalingOperator(0. if key in constants else 1., dom) - for key, dom in field.domain.items()] - bdop = BlockDiagonalOperator(field.domain, tuple(ops)) + ops = {key: ScalingOperator(0. if key in constants else 1., dom) + for key, dom in field.domain.items()} + bdop = BlockDiagonalOperator(field.domain, ops) return Linearization(field, bdop, want_metric=want_metric) diff --git a/nifty5/operators/block_diagonal_operator.py b/nifty5/operators/block_diagonal_operator.py index be29fc3ae6111be3832ad98917892009488de68f..72e3a72e6ff92fc73a7376b3bf5f9927053158d6 100644 --- a/nifty5/operators/block_diagonal_operator.py +++ b/nifty5/operators/block_diagonal_operator.py @@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator): """ Parameters ---------- + domain : MultiDomain + Domain and target of the operator. operators : dict - Dictionary with operators domain names as keys and LinearOperators as - items. + Dictionary with subdomain names as keys and LinearOperators as items. """ def __init__(self, domain, operators): if not isinstance(domain, MultiDomain): raise TypeError("MultiDomain expected") - if not isinstance(operators, tuple): - raise TypeError("tuple expected") self._domain = domain - self._ops = operators + self._ops = tuple(operators[key] for key in domain.keys()) self._capability = self._all_ops for op in self._ops: if op is not None: @@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator): def _combine_chain(self, op): if self._domain != op._domain: raise ValueError("domain mismatch") - res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops)) + res = {key: v1(v2) + for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} return BlockDiagonalOperator(self._domain, res) def _combine_sum(self, op, selfneg, opneg): from ..operators.sum_operator import SumOperator if self._domain != op._domain: raise ValueError("domain mismatch") - res = tuple(SumOperator.make([v1, v2], [selfneg, opneg]) - for v1, v2 in zip(self._ops, op._ops)) + res = {key: SumOperator.make([v1, v2], [selfneg, opneg]) + for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} return BlockDiagonalOperator(self._domain, res) diff --git a/nifty5/sugar.py b/nifty5/sugar.py index 68484a326e56b8b1ebb4e9a177c2b88d674e1d0c..7e1342b3b5f073b266c9d1f4630a63ecc72af10b 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -363,7 +363,7 @@ def makeOp(input): return DiagonalOperator(input) if isinstance(input, MultiField): return BlockDiagonalOperator( - input.domain, tuple(makeOp(val) for val in input.values())) + input.domain, {key: makeOp(val) for key, val in enumerate(input)}) raise NotImplementedError diff --git a/test/test_multi_field.py b/test/test_multi_field.py index a316b3604421558ffb1c53a3c59b3fae9e7b2fcf..6990c038b08b137d5c1e986a5ff2c5e032f6fc15 100644 --- a/test/test_multi_field.py +++ b/test/test_multi_field.py @@ -43,7 +43,8 @@ def test_dataconv(): def test_blockdiagonal(): - op = ift.BlockDiagonalOperator(dom, (ift.ScalingOperator(20., dom["d1"]),)) + op = ift.BlockDiagonalOperator( + dom, {"d1": ift.ScalingOperator(20., dom["d1"])}) op2 = op(op) ift.extra.consistency_check(op2) assert_equal(type(op2), ift.BlockDiagonalOperator)