diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py index 0546f35a1a89b99eb8dfd1258a1770805f49225e..3d9ec9fc32d9ca2dca5bddc2077b8e6a11814fdc 100644 --- a/nifty5/multi/block_diagonal_operator.py +++ b/nifty5/multi/block_diagonal_operator.py @@ -45,7 +45,7 @@ class BlockDiagonalOperator(EndomorphicOperator): if self._domain is not op._domain: raise ValueError("domain mismatch") res = {key : v1*v2 for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} - return BlockDiagonalOperator(res) + return BlockDiagonalOperator(self._domain, res) def _combine_sum(self, op, selfneg, opneg): from ..operators.sum_operator import SumOperator @@ -54,4 +54,4 @@ class BlockDiagonalOperator(EndomorphicOperator): res = {} for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops): res[key] = SumOperator.make([v1, v2], [selfneg, opneg]) - return BlockDiagonalOperator(res) + return BlockDiagonalOperator(self._domain, res) diff --git a/test/test_multi_field.py b/test/test_multi_field.py index 9409b3f6a4b94950531f4fa296baeb77d13ec0bd..820d82da7a27080233414ee0030c8de6ceb9fc3b 100644 --- a/test/test_multi_field.py +++ b/test/test_multi_field.py @@ -39,8 +39,8 @@ class Test_Functionality(unittest.TestCase): assert_equal(val.local_data, f2[key].local_data) def test_blockdiagonal(self): - op = ift.BlockDiagonalOperator({"d1": - ift.ScalingOperator(20., dom["d1"])}) + op = ift.BlockDiagonalOperator( + dom, {"d1": ift.ScalingOperator(20., dom["d1"])}) op2 = op*op ift.extra.consistency_check(op2) assert_equal(type(op2), ift.BlockDiagonalOperator)