diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index 5c722d2bba208ea930e00a28a837b7e87bf4137a..ef620e45a3f3a4420f0327dc959f3f43cd725cbf 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -13,5 +13,6 @@ from .power_projection_operator import PowerProjectionOperator from .dof_projection_operator import DOFProjectionOperator from .chain_operator import ChainOperator from .sum_operator import SumOperator +from .scaling_operator import ScalingOperator from .inverse_operator import InverseOperator from .adjoint_operator import AdjointOperator diff --git a/nifty/operators/sum_operator.py b/nifty/operators/sum_operator.py index fb9f39938abd16e6612a8ed0507f70869d5ad7f4..09125fc22661829fc22719abdc95f5fee948fc18 100644 --- a/nifty/operators/sum_operator.py +++ b/nifty/operators/sum_operator.py @@ -41,7 +41,7 @@ class SumOperator(LinearOperator): # Step 2: unpack SumOperators opsnew = [] negnew = [] - for op, ng in zip (ops, neg): + for op, ng in zip(ops, neg): if isinstance(op, SumOperator): opsnew += op._ops if ng: @@ -81,14 +81,33 @@ class SumOperator(LinearOperator): ops = opsnew neg = negnew # Step 4: combine DiagonalOperators where possible - # (TBD) + processed = [False] * len(ops) + opsnew = [] + negnew = [] + for i in range(len(ops)): + if not processed[i]: + if isinstance(ops[i], DiagonalOperator): + diag = ops[i].diagonal()*(-1 if neg[i] else 1) + for j in range(i+1, len(ops)): + if (isinstance(ops[j], DiagonalOperator) and + ops[i]._spaces == ops[j]._spaces): + diag += ops[j].diagonal()*(-1 if neg[j] else 1) + processed[j] = True + opsnew.append(DiagonalOperator(diag, ops[i].domain, + ops[i]._spaces)) + negnew.append(False) + else: + opsnew.append(ops[i]) + negnew.append(neg[i]) + ops = opsnew + neg = negnew return ops, neg @staticmethod def make(ops, neg): ops = tuple(ops) neg = tuple(neg) - if len(ops)!= len(neg): + if len(ops) != len(neg): raise ValueError("length mismatch between ops and neg") ops, neg = SumOperator.simplify(ops, neg) if len(ops) == 1 and not neg[0]: diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index c82aa1ce684375b1b5e82e854f805ffdf6814c80..2e3d540cd922b5642b435e95e1adf25da47c15ee 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -1,5 +1,5 @@ import unittest -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_equal import nifty2go as ift from test.common import generate_spaces from itertools import product @@ -41,3 +41,23 @@ class ComposedOperator_Tests(unittest.TestCase): assert_allclose(ift.dobj.to_global_data(tt1.val), ift.dobj.to_global_data(rand1.val)) + + @expand(product(spaces)) + def test_sum(self, space): + op1 = ift.DiagonalOperator(ift.Field(space, 2.)) + op2 = ift.ScalingOperator(3., space) + full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2 + x = ift.Field(space, 1.) + res = full_op(x) + assert_equal(isinstance(full_op, ift.DiagonalOperator), True) + assert_allclose(ift.dobj.to_global_data(res.val), 11.) + + @expand(product(spaces)) + def test_chain(self, space): + op1 = ift.DiagonalOperator(ift.Field(space, 2.)) + op2 = ift.ScalingOperator(3., space) + full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2 + x = ift.Field(space, 1.) + res = full_op(x) + assert_equal(isinstance(full_op, ift.DiagonalOperator), True) + assert_allclose(ift.dobj.to_global_data(res.val), 432.)