From c960121b31b1187bb11d3beb79e64af3a2461c5a Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 24 Jul 2018 21:21:59 +0200
Subject: [PATCH] more

---
 nifty5/multi/multi_domain.py             | 12 +++++
 nifty5/multi/multi_field.py              |  1 +
 nifty5/operator.py                       | 20 ++++---
 nifty5/operators/relaxed_sum_operator.py | 66 ++++++++++++++++++++++++
 nifty5/sugar.py                          |  3 +-
 5 files changed, 93 insertions(+), 9 deletions(-)
 create mode 100644 nifty5/operators/relaxed_sum_operator.py

diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py
index 715e51570..7876184af 100644
--- a/nifty5/multi/multi_domain.py
+++ b/nifty5/multi/multi_domain.py
@@ -105,3 +105,15 @@ class MultiDomain(object):
         for key, dom in zip(self._keys, self._domains):
             res += key+": "+str(dom)+"\n"
         return res
+
+    @staticmethod
+    def union(inp):
+        res = {}
+        for dom in inp:
+            for key, subdom in zip(dom._keys, dom._domains):
+                if key in res:
+                    if res[key] is not subdom:
+                        raise ValueError("domain mismatch")
+                else:
+                    res[key] = subdom
+        return MultiDomain.make(res)
diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py
index c79e2767a..3c4457264 100644
--- a/nifty5/multi/multi_field.py
+++ b/nifty5/multi/multi_field.py
@@ -121,6 +121,7 @@ class MultiField(object):
 
     @staticmethod
     def full(domain, val):
+        domain = MultiDomain.make(domain)
         return MultiField(domain, tuple(Field.full(dom, val)
                           for dom in domain._domains))
 
diff --git a/nifty5/operator.py b/nifty5/operator.py
index ecccefbf0..a6abdf5ee 100644
--- a/nifty5/operator.py
+++ b/nifty5/operator.py
@@ -38,7 +38,10 @@ class Linearization(object):
 
     def __add__(self, other):
         if isinstance(other, Linearization):
-            return Linearization(self._val+other._val, self._jac+other._jac)
+            from .operators.relaxed_sum_operator import RelaxedSumOperator
+            return Linearization(
+                MultiField.combine((self._val, other._val)),
+                RelaxedSumOperator((self._jac, other._jac)))
         if isinstance(other, (int, float, complex, Field, MultiField)):
             return Linearization(self._val+other, self._jac)
 
@@ -52,10 +55,10 @@ class Linearization(object):
         return (-self).__add__(other)
 
     def __mul__(self, other):
-        from .operators.diagonal_operator import DiagonalOperator
+        from .sugar import makeOp
         if isinstance(other, Linearization):
-            d1 = DiagonalOperator(self._val)
-            d2 = DiagonalOperator(other._val)
+            d1 = makeOp(self._val)
+            d2 = makeOp(other._val)
             return Linearization(self._val*other._val,
                                  self._jac*d2 + d1*other._jac)
         if isinstance(other, (int, float, complex)):
@@ -63,15 +66,16 @@ class Linearization(object):
             #    return ...
             return Linearization(self._val*other, self._jac*other)
         if isinstance(other, (Field, MultiField)):
-            d2 = DiagonalOperator(other)
+            d2 = makeOp(other)
             return Linearization(self._val*other, self._jac*d2)
         raise TypeError
 
     def __rmul__(self, other):
+        from .sugar import makeOp
         if isinstance(other, (int, float, complex)):
             return Linearization(self._val*other, self._jac*other)
         if isinstance(other, (Field, MultiField)):
-            d1 = DiagonalOperator(other)
+            d1 = makeOp(other)
             return Linearization(self._val*other, d1*self._jac)
 
     @staticmethod
@@ -80,8 +84,8 @@ class Linearization(object):
         return Linearization(field, ScalingOperator(1., field.domain))
     @staticmethod
     def make_const(field):
-        from .operators.scaling_operator import ScalingOperator
-        return Linearization(field, ScalingOperator(0., {}))
+        from .operators.null_operator import NullOperator
+        return Linearization(field, NullOperator({}, field.domain))
 
 class Operator(NiftyMetaBase()):
     """Transforms values living on one domain into values living on another
diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py
new file mode 100644
index 000000000..f7f50a6ec
--- /dev/null
+++ b/nifty5/operators/relaxed_sum_operator.py
@@ -0,0 +1,66 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+
+from ..compat import *
+from ..utilities import my_sum
+from .linear_operator import LinearOperator
+from ..multi.multi_domain import MultiDomain
+
+
+class RelaxedSumOperator(LinearOperator):
+    """Class representing sums of operators with compatible MultiDomains."""
+
+    def __init__(self, ops):
+        super(RelaxedSumOperator, self).__init__()
+        self._ops = ops
+        self._domain = MultiDomain.union([op.domain for op in ops])
+        self._target = MultiDomain.union([op.target for op in ops])
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        for op in ops:
+            self._capability &= op.capability
+
+    @property
+    def domain(self):
+        return self._domain
+
+    @property
+    def target(self):
+        return self._target
+
+    @property
+    def adjoint(self):
+        return RelaxedSumOperator([op.adjoint for op in self._ops])
+
+    @property
+    def capability(self):
+        return self._capability
+
+    def apply(self, x, mode):
+        self._check_mode(mode)
+        res = None
+        for op in self._ops:
+            tmp = x.extract(op._dom(mode), mode)
+            if res is None:
+                res = tmp
+            else:
+                res = MultiField.combine([res, tmp])
+        return res
diff --git a/nifty5/sugar.py b/nifty5/sugar.py
index 942d11a4e..cc2a2963c 100644
--- a/nifty5/sugar.py
+++ b/nifty5/sugar.py
@@ -251,7 +251,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
     def func(f):
         def func2(x):
             if isinstance(x, MultiField):
-                return MultiField({key: func2(val) for key, val in x.items()})
+                return MultiField(x.domain,
+                                  tuple(func2(val) for val in x.values()))
             elif isinstance(x, Field):
                 fu = getattr(dobj, f)
                 return Field(domain=x._domain, val=fu(x.val))
-- 
GitLab