Commit 17013e3d by Martin Reinecke

### more SumOperator optimizations; new tests

parent 1509dfd8
Pipeline #23528 passed with stage
in 4 minutes and 36 seconds
 ... ... @@ -13,5 +13,6 @@ from .power_projection_operator import PowerProjectionOperator from .dof_projection_operator import DOFProjectionOperator from .chain_operator import ChainOperator from .sum_operator import SumOperator from .scaling_operator import ScalingOperator from .inverse_operator import InverseOperator from .adjoint_operator import AdjointOperator
 ... ... @@ -41,7 +41,7 @@ class SumOperator(LinearOperator): # Step 2: unpack SumOperators opsnew = [] negnew = [] for op, ng in zip (ops, neg): for op, ng in zip(ops, neg): if isinstance(op, SumOperator): opsnew += op._ops if ng: ... ... @@ -81,14 +81,33 @@ class SumOperator(LinearOperator): ops = opsnew neg = negnew # Step 4: combine DiagonalOperators where possible # (TBD) processed = [False] * len(ops) opsnew = [] negnew = [] for i in range(len(ops)): if not processed[i]: if isinstance(ops[i], DiagonalOperator): diag = ops[i].diagonal()*(-1 if neg[i] else 1) for j in range(i+1, len(ops)): if (isinstance(ops[j], DiagonalOperator) and ops[i]._spaces == ops[j]._spaces): diag += ops[j].diagonal()*(-1 if neg[j] else 1) processed[j] = True opsnew.append(DiagonalOperator(diag, ops[i].domain, ops[i]._spaces)) negnew.append(False) else: opsnew.append(ops[i]) negnew.append(neg[i]) ops = opsnew neg = negnew return ops, neg @staticmethod def make(ops, neg): ops = tuple(ops) neg = tuple(neg) if len(ops)!= len(neg): if len(ops) != len(neg): raise ValueError("length mismatch between ops and neg") ops, neg = SumOperator.simplify(ops, neg) if len(ops) == 1 and not neg[0]: ... ...
 import unittest from numpy.testing import assert_allclose from numpy.testing import assert_allclose, assert_equal import nifty2go as ift from test.common import generate_spaces from itertools import product ... ... @@ -41,3 +41,23 @@ class ComposedOperator_Tests(unittest.TestCase): assert_allclose(ift.dobj.to_global_data(tt1.val), ift.dobj.to_global_data(rand1.val)) @expand(product(spaces)) def test_sum(self, space): op1 = ift.DiagonalOperator(ift.Field(space, 2.)) op2 = ift.ScalingOperator(3., space) full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2 x = ift.Field(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) assert_allclose(ift.dobj.to_global_data(res.val), 11.) @expand(product(spaces)) def test_chain(self, space): op1 = ift.DiagonalOperator(ift.Field(space, 2.)) op2 = ift.ScalingOperator(3., space) full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2 x = ift.Field(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) assert_allclose(ift.dobj.to_global_data(res.val), 432.)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!