From 1f5a367c3242739c547b8a7b4bd16600ebdbc9d3 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 8 Jan 2018 11:18:07 +0100
Subject: [PATCH] tweak SumOperator; fix a few other problems

---
 demos/paper_demos/cartesian_wiener_filter.py |  2 +-
 nifty/field.py                               |  4 +-
 nifty/operators/linear_operator.py           |  6 +-
 nifty/operators/sum_operator.py              | 86 +++++++++++++++++---
 4 files changed, 79 insertions(+), 19 deletions(-)

diff --git a/demos/paper_demos/cartesian_wiener_filter.py b/demos/paper_demos/cartesian_wiener_filter.py
index 95468b64a..137a0e6ad 100644
--- a/demos/paper_demos/cartesian_wiener_filter.py
+++ b/demos/paper_demos/cartesian_wiener_filter.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
     mask_2[N2_10*7:N2_10*9] = 0.
     mask_2 = ift.Field(signal_space_2, ift.dobj.from_global_data(mask_2))
 
-    R = ift.ResponseOperator(signal_domain, spaces=(0, 1),
+    R = ift.ResponseOperator(signal_domain,
                              sigma=(response_sigma_1, response_sigma_2),
                              exposure=(mask_1, mask_2))
     data_domain = R.target
diff --git a/nifty/field.py b/nifty/field.py
index 4428b2026..467561d3f 100644
--- a/nifty/field.py
+++ b/nifty/field.py
@@ -370,10 +370,10 @@ class Field(object):
         return self.copy()
 
     def __neg__(self):
-        return Field(self._domain, -self.val, self.dtype)
+        return Field(self._domain, -self.val)
 
     def __abs__(self):
-        return Field(self._domain, dobj.abs(self.val), self.dtype)
+        return Field(self._domain, dobj.abs(self.val))
 
     def _contraction_helper(self, op, spaces):
         if spaces is None:
diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py
index a27c29461..536aa107a 100644
--- a/nifty/operators/linear_operator.py
+++ b/nifty/operators/linear_operator.py
@@ -104,7 +104,7 @@ class LinearOperator(with_metaclass(
     def __add__(self, other):
         from .sum_operator import SumOperator
         other = self._toOperator(other, self.domain)
-        return SumOperator(self, other)
+        return SumOperator.make([self, other], [False, False])
 
     def __radd__(self, other):
         return self.__add__(other)
@@ -112,13 +112,13 @@ class LinearOperator(with_metaclass(
     def __sub__(self, other):
         from .sum_operator import SumOperator
         other = self._toOperator(other, self.domain)
-        return SumOperator(self, other, neg=True)
+        return SumOperator.make([self, other], [False, True])
 
     # MR FIXME: this might be more complicated ...
     def __rsub__(self, other):
         from .sum_operator import SumOperator
         other = self._toOperator(other, self.domain)
-        return SumOperator(other, self, neg=True)
+        return SumOperator.make(other, self, [False, True])
 
     def supports(self, ops):
         return False
diff --git a/nifty/operators/sum_operator.py b/nifty/operators/sum_operator.py
index 070c1634e..51178b845 100644
--- a/nifty/operators/sum_operator.py
+++ b/nifty/operators/sum_operator.py
@@ -20,20 +20,80 @@ from .linear_operator import LinearOperator
 
 
 class SumOperator(LinearOperator):
-    def __init__(self, op1, op2, neg=False):
+    def __init__(self, ops, neg, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
         super(SumOperator, self).__init__()
-        if op1.domain != op2.domain or op1.target != op2.target:
-            raise ValueError("domain mismatch")
-        self._capability = (op1.capability & op2.capability &
-                            (self.TIMES | self.ADJOINT_TIMES))
-        neg1 = op1._neg if isinstance(op1, SumOperator) else (False,)
-        op1 = op1._ops if isinstance(op1, SumOperator) else (op1,)
-        neg2 = op2._neg if isinstance(op2, SumOperator) else (False,)
-        op2 = op2._ops if isinstance(op2, SumOperator) else (op2,)
-        if neg:
-            neg2 = tuple(not n for n in neg2)
-        self._ops = op1 + op2
-        self._neg = neg1 + neg2
+        self._ops = ops
+        self._neg = neg
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        for op in ops:
+            self._capability &= op.capability
+
+    @staticmethod
+    def simplify(ops, neg):
+        from .scaling_operator import ScalingOperator
+        from .diagonal_operator import DiagonalOperator
+        # Step 1: verify domains
+        for op in ops[1:]:
+            if op.domain != ops[0].domain or op.target != ops[0].target:
+                raise ValueError("domain mismatch")
+        # Step 2: unpack SumOperators
+        opsnew = []
+        negnew = []
+        for op, ng in zip (ops, neg):
+            if isinstance(op, SumOperator):
+                opsnew += op._ops
+                if ng:
+                    negtmp += [not n for n in ng]
+                else:
+                    negtmp += list(ng)
+            else:
+                opsnew.append(op)
+                negnew.append(ng)
+        ops = opsnew
+        neg = negnew
+        # Step 3: collect ScalingOperators
+        sum = 0.
+        opsnew = []
+        negnew = []
+        lastdom = ops[-1].domain
+        for op, ng in zip(ops, neg):
+            if isinstance(op, ScalingOperator):
+                sum += op._factor * (-1 if ng else 1)
+            else:
+                opsnew.append(op)
+                negnew.append(ng)
+        if sum != 0.:
+            # try to absorb the factor into a DiagonalOperator
+            for i in range(len(opsnew)):
+                if isinstance(opsnew[i], DiagonalOperator):
+                    sum *= (-1 if negnew[i] else 1)
+                    opsnew[i] = DiagonalOperator(opsnew[i].diagonal()+sum,
+                                                 domain=opsnew[i].domain,
+                                                 spaces=opsnew[i]._spaces)
+                    sum = 0.
+                    break
+        if sum != 0:
+            # have to add the scaling operator at the end
+            opsnew.append(ScalingOperator(sum, lastdom))
+            newnew.append(False)
+        ops = opsnew
+        neg = negnew
+        # Step 4: combine DiagonalOperators where possible
+        # (TBD)
+        return ops, neg
+
+    @staticmethod
+    def make(ops, neg):
+        ops = tuple(ops)
+        neg = tuple(neg)
+        if len(ops)!= len(neg):
+            raise ValueError("length mismatch between ops and neg")
+        ops, neg = SumOperator.simplify(ops, neg)
+        if len(ops) == 1 and not neg[0]:
+            return ops[0]
+        return SumOperator(ops, neg, _callingfrommake=True)
 
     @property
     def domain(self):
-- 
GitLab