Commit 17013e3d authored by Martin Reinecke's avatar Martin Reinecke

more SumOperator optimizations; new tests

parent 1509dfd8
Pipeline #23528 passed with stage
in 4 minutes and 36 seconds
......@@ -13,5 +13,6 @@ from .power_projection_operator import PowerProjectionOperator
from .dof_projection_operator import DOFProjectionOperator
from .chain_operator import ChainOperator
from .sum_operator import SumOperator
from .scaling_operator import ScalingOperator
from .inverse_operator import InverseOperator
from .adjoint_operator import AdjointOperator
......@@ -41,7 +41,7 @@ class SumOperator(LinearOperator):
# Step 2: unpack SumOperators
opsnew = []
negnew = []
for op, ng in zip (ops, neg):
for op, ng in zip(ops, neg):
if isinstance(op, SumOperator):
opsnew += op._ops
if ng:
......@@ -81,14 +81,33 @@ class SumOperator(LinearOperator):
ops = opsnew
neg = negnew
# Step 4: combine DiagonalOperators where possible
# (TBD)
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
diag = ops[i].diagonal()*(-1 if neg[i] else 1)
for j in range(i+1, len(ops)):
if (isinstance(ops[j], DiagonalOperator) and
ops[i]._spaces == ops[j]._spaces):
diag += ops[j].diagonal()*(-1 if neg[j] else 1)
processed[j] = True
opsnew.append(DiagonalOperator(diag, ops[i].domain,
ops[i]._spaces))
negnew.append(False)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg
@staticmethod
def make(ops, neg):
ops = tuple(ops)
neg = tuple(neg)
if len(ops)!= len(neg):
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg = SumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
......
import unittest
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal
import nifty2go as ift
from test.common import generate_spaces
from itertools import product
......@@ -41,3 +41,23 @@ class ComposedOperator_Tests(unittest.TestCase):
assert_allclose(ift.dobj.to_global_data(tt1.val),
ift.dobj.to_global_data(rand1.val))
@expand(product(spaces))
def test_sum(self, space):
op1 = ift.DiagonalOperator(ift.Field(space, 2.))
op2 = ift.ScalingOperator(3., space)
full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
x = ift.Field(space, 1.)
res = full_op(x)
assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
assert_allclose(ift.dobj.to_global_data(res.val), 11.)
@expand(product(spaces))
def test_chain(self, space):
op1 = ift.DiagonalOperator(ift.Field(space, 2.))
op2 = ift.ScalingOperator(3., space)
full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
x = ift.Field(space, 1.)
res = full_op(x)
assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
assert_allclose(ift.dobj.to_global_data(res.val), 432.)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment