From 3d1fadb69cbd10ed689fcbe3f8bd8ac1b9a10dd4 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 22 Jun 2018 11:50:20 +0200
Subject: [PATCH] first try

---
 nifty5/domain_tuple.py                    | 17 ++++++++--
 nifty5/multi/multi_domain.py              | 41 +++++++++++++++++++++++
 nifty5/multi/multi_field.py               | 15 +++++++--
 nifty5/operators/linear_operator.py       |  2 +-
 nifty5/operators/sum_operator.py          | 21 +++++++-----
 test/test_minimization/test_minimizers.py |  2 +-
 6 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py
index 39eafa885..b164988ab 100644
--- a/nifty5/domain_tuple.py
+++ b/nifty5/domain_tuple.py
@@ -128,13 +128,24 @@ class DomainTuple(object):
     def __eq__(self, x):
         if not isinstance(x, DomainTuple):
             x = DomainTuple.make(x)
-        if self is x:
-            return True
-        return self._dom == x._dom
+        return self is x
 
     def __ne__(self, x):
         return not self.__eq__(x)
 
+    def compatibleTo(self, x):
+        return self.__eq__(x)
+
+    def subsetOf(self, x):
+        return self.__eq__(x)
+
+    def unitedWith(self, x):
+        if not isinstance(x, DomainTuple):
+            x = DomainTuple.make(x)
+        if self != x:
+            raise ValueError("domain mismatch")
+        return self
+
     def __str__(self):
         res = "DomainTuple, len: " + str(len(self))
         for i in self:
diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py
index 89272688b..24b7524a3 100644
--- a/nifty5/multi/multi_domain.py
+++ b/nifty5/multi/multi_domain.py
@@ -71,3 +71,44 @@ class MultiDomain(frozendict):
         obj = MultiDomain(domain, _callingfrommake=True)
         MultiDomain._domainCache[domain] = obj
         return obj
+
+    def __eq__(self, x):
+        if not isinstance(x, MultiDomain):
+            x = MultiDomain.make(x)
+        return self is x
+
+    def __ne__(self, x):
+        return not self.__eq__(x)
+
+    def compatibleTo(self, x):
+        if not isinstance(x, MultiDomain):
+            x = MultiDomain.make(x)
+        commonKeys = set(self.keys()) & set(x.keys())
+        for key in commonKeys:
+            if self[key] != x[key]:
+                return False
+        return True
+
+    def subsetOf(self, x):
+        if not isinstance(x, MultiDomain):
+            x = MultiDomain.make(x)
+        for key in self.keys():
+            if not key in x:
+                return False
+            if self[key] != x[key]:
+                return False
+        return True
+
+    def unitedWith(self, x):
+        if not isinstance(x, MultiDomain):
+            x = MultiDomain.make(x)
+        if self == x:
+            return self
+        if not self.compatibleTo(x):
+            raise ValueError("domain mismatch")
+        res = {}
+        for key, val in self.items():
+            res[key] = val
+        for key, val in x.items():
+            res[key] = val
+        return MultiDomain.make(res)
diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py
index 8e96a6e24..fe182774f 100644
--- a/nifty5/multi/multi_field.py
+++ b/nifty5/multi/multi_field.py
@@ -199,9 +199,18 @@ for op in ["__add__", "__radd__", "__iadd__",
     def func(op):
         def func2(self, other):
             if isinstance(other, MultiField):
-                self._check_domain(other)
-                result_val = {key: getattr(sub_field, op)(other[key])
-                              for key, sub_field in self.items()}
+                if self._domain == other._domain:
+                    result_val = {key: getattr(sub_field, op)(other[key])
+                                  for key, sub_field in self.items()}
+                else:
+                    if not self._domain.compatibleTo(other.domain):
+                        raise ValueError("domain mismatch")
+                    fullkeys = set(self._domain.keys()) | set(other._domain.keys())
+                    result_val = {}
+                    for key in fullkeys:
+                        f1 = self[key] if key in self._domain.keys() else other[key]*0
+                        f2 = other[key] if key in other._domain.keys() else self[key]*0
+                        result_val[key] = getattr(f1, op)(f2)
             else:
                 result_val = {key: getattr(val, op)(other)
                               for key, val in self.items()}
diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py
index 8cc61dd32..0f8d0aa33 100644
--- a/nifty5/operators/linear_operator.py
+++ b/nifty5/operators/linear_operator.py
@@ -280,5 +280,5 @@ class LinearOperator(NiftyMetaBase()):
 
     def _check_input(self, x, mode):
         self._check_mode(mode)
-        if x.domain != self._dom(mode):
+        if not self._dom(mode).subsetOf(x.domain):
             raise ValueError("The operator's and field's domains don't match.")
diff --git a/nifty5/operators/sum_operator.py b/nifty5/operators/sum_operator.py
index d32d7bfd7..b2bf53eb3 100644
--- a/nifty5/operators/sum_operator.py
+++ b/nifty5/operators/sum_operator.py
@@ -23,12 +23,14 @@ import numpy as np
 class SumOperator(LinearOperator):
     """Class representing sums of operators."""
 
-    def __init__(self, ops, neg, _callingfrommake=False):
+    def __init__(self, ops, neg, dom, tgt, _callingfrommake=False):
         if not _callingfrommake:
             raise NotImplementedError
         super(SumOperator, self).__init__()
         self._ops = ops
         self._neg = neg
+        self._domain = dom
+        self._target = tgt
         self._capability = self.TIMES | self.ADJOINT_TIMES
         for op in ops:
             self._capability &= op.capability
@@ -38,9 +40,12 @@ class SumOperator(LinearOperator):
         from .scaling_operator import ScalingOperator
         from .diagonal_operator import DiagonalOperator
         # Step 1: verify domains
+        dom = ops[0].domain
+        tgt = ops[0].target
         for op in ops[1:]:
-            if op.domain != ops[0].domain or op.target != ops[0].target:
-                raise ValueError("domain mismatch")
+            dom = dom.unitedWith(op.domain)
+            tgt = tgt.unitedWith(op.target)
+
         # Step 2: unpack SumOperators
         opsnew = []
         negnew = []
@@ -124,7 +129,7 @@ class SumOperator(LinearOperator):
                     negnew.append(neg[i])
         ops = opsnew
         neg = negnew
-        return ops, neg
+        return ops, neg, dom, tgt
 
     @staticmethod
     def make(ops, neg):
@@ -134,18 +139,18 @@ class SumOperator(LinearOperator):
             raise ValueError("ops is empty")
         if len(ops) != len(neg):
             raise ValueError("length mismatch between ops and neg")
-        ops, neg = SumOperator.simplify(ops, neg)
+        ops, neg, dom, tgt = SumOperator.simplify(ops, neg)
         if len(ops) == 1 and not neg[0]:
             return ops[0]
-        return SumOperator(ops, neg, _callingfrommake=True)
+        return SumOperator(ops, neg, dom, tgt, _callingfrommake=True)
 
     @property
     def domain(self):
-        return self._ops[0].domain
+        return self._domain
 
     @property
     def target(self):
-        return self._ops[0].target
+        return self._target
 
     @property
     def adjoint(self):
diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py
index 060fe30b3..c1c7632cf 100644
--- a/test/test_minimization/test_minimizers.py
+++ b/test/test_minimization/test_minimizers.py
@@ -76,7 +76,7 @@ class Test_Minimizers(unittest.TestCase):
         except ImportError:
             raise SkipTest
         np.random.seed(42)
-        space = ift.UnstructuredDomain((2,))
+        space = ift.DomainTuple.make(ift.UnstructuredDomain((2,)))
         starting_point = ift.Field.from_random('normal', domain=space)*10
 
         class RBEnergy(ift.Energy):
-- 
GitLab