From d07a537eda98c052d849087e93083e703d7230c8 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sat, 7 Jul 2018 16:08:50 +0200
Subject: [PATCH] more work

---
 nifty5/library/correlated_fields.py     |  4 +--
 nifty5/library/point_sources.py         |  4 ++-
 nifty5/multi/block_diagonal_operator.py | 40 ++++++++++++-------------
 nifty5/multi/multi_domain.py            | 10 ++++---
 nifty5/multi/multi_field.py             | 15 ++++++----
 nifty5/operators/selection_operator.py  |  4 +--
 nifty5/sugar.py                         |  4 ++-
 7 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py
index a6f643925..5f92291fe 100644
--- a/nifty5/library/correlated_fields.py
+++ b/nifty5/library/correlated_fields.py
@@ -22,7 +22,7 @@ def make_correlated_field(s_space, amplitude_model):
     position['xi'] = Field.from_random('normal', h_space)
     position['tau'] = amplitude_model.position['tau']
     position['phi'] = amplitude_model.position['phi']
-    position = MultiField(position)
+    position = MultiField.from_dict(position)
 
     xi = Variable(position)['xi']
     A = power_distributor(amplitude_model)
@@ -70,7 +70,7 @@ def make_mf_correlated_field(s_space_spatial, s_space_energy,
     a = a_spatial*a_energy
     A = pd(a)
 
-    position = MultiField({'xi': Field.from_random('normal', h_space)})
+    position = MultiField.from_dict({'xi': Field.from_random('normal', h_space)})
     xi = Variable(position)['xi']
     correlated_field_h = A*xi
     correlated_field = ht(correlated_field_h)
diff --git a/nifty5/library/point_sources.py b/nifty5/library/point_sources.py
index 67e2b162e..4ba9572cb 100644
--- a/nifty5/library/point_sources.py
+++ b/nifty5/library/point_sources.py
@@ -22,6 +22,7 @@ class PointSources(Model):
     @memo
     def value(self):
         points = self.position['points'].local_data
+        # MR FIXME?!
         points = np.clip(points, None, 8.2)
         points = Field.from_local_data(self.position['points'].domain, points)
         return self.IG(points, self._alpha, self._q)
@@ -40,7 +41,8 @@ class PointSources(Model):
         outer = 1/outer_inv
         grad = Field.from_local_data(self.position['points'].domain,
                                      inner*outer)
-        grad = makeOp(MultiField({'points': grad}))
+        grad = makeOp(MultiField.from_dict({"points": grad},
+                                           self.position._domain))
         return SelectionOperator(grad.target, 'points')*grad
 
     @staticmethod
diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py
index 99877d430..0546f35a1 100644
--- a/nifty5/multi/block_diagonal_operator.py
+++ b/nifty5/multi/block_diagonal_operator.py
@@ -5,7 +5,7 @@ from .multi_field import MultiField
 
 
 class BlockDiagonalOperator(EndomorphicOperator):
-    def __init__(self, operators):
+    def __init__(self, domain, operators):
         """
         Parameters
         ----------
@@ -14,12 +14,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
             LinearOperators as items
         """
         super(BlockDiagonalOperator, self).__init__()
-        self._operators = operators
-        self._domain = MultiDomain.make(
-            {key: op.domain for key, op in self._operators.items()})
+        self._domain = domain
+        self._ops = tuple(operators[key] for key in self.domain.keys())
         self._cap = self._all_ops
-        for op in self._operators.values():
-            self._cap &= op.capability
+        for op in self._ops:
+            if op is not None:
+                self._cap &= op.capability
 
     @property
     def domain(self):
@@ -31,27 +31,27 @@ class BlockDiagonalOperator(EndomorphicOperator):
 
     def apply(self, x, mode):
         self._check_input(x, mode)
-        val = tuple(self._operators[key].apply(x._val[i], mode=mode)
-                    for i, key in enumerate(x.keys()))
+        val = tuple(op.apply(v, mode=mode) if op is not None else None
+                    for op, v in zip(self._ops, x.values()))
         return MultiField(self._domain, val)
 
-    def draw_sample(self, from_inverse=False, dtype=np.float64):
-        dtype = MultiField.build_dtype(dtype, self._domain)
-        val = tuple(self._operators[key].draw_sample(from_inverse, dtype[key])
-                    for key in self._domain._keys)
-        return MultiField(self._domain, val)
+#    def draw_sample(self, from_inverse=False, dtype=np.float64):
+#        dtype = MultiField.build_dtype(dtype, self._domain)
+#        val = tuple(op.draw_sample(from_inverse, dtype)
+#                    for op in self._op)
+#        return MultiField(self._domain, val)
 
     def _combine_chain(self, op):
-        res = {}
-        for key in self._operators.keys():
-            res[key] = self._operators[key]*op._operators[key]
+        if self._domain is not op._domain:
+            raise ValueError("domain mismatch")
+        res = {key : v1*v2 for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
         return BlockDiagonalOperator(res)
 
     def _combine_sum(self, op, selfneg, opneg):
         from ..operators.sum_operator import SumOperator
+        if self._domain is not op._domain:
+            raise ValueError("domain mismatch")
         res = {}
-        for key in self._operators.keys():
-            res[key] = SumOperator.make([self._operators[key],
-                                         op._operators[key]],
-                                        [selfneg, opneg])
+        for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops):
+            res[key] = SumOperator.make([v1, v2], [selfneg, opneg])
         return BlockDiagonalOperator(res)
diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py
index 6497c780e..6687f165d 100644
--- a/nifty5/multi/multi_domain.py
+++ b/nifty5/multi/multi_domain.py
@@ -6,8 +6,6 @@ from ..utilities import frozendict
 
 class MultiDomain(object):
     _domainCache = {}
-    _subsetCache = set()
-    _compatCache = set()
 
     def __init__(self, dict, _callingfrommake=False):
         if not _callingfrommake:
@@ -15,7 +13,7 @@ class MultiDomain(object):
                 'To create a MultiDomain call `MultiDomain.make()`.')
         self._keys = tuple(sorted(dict.keys()))
         self._domains = tuple(dict[key] for key in self._keys)
-        self._dict = frozendict({key: i for i, key in enumerate(self._keys)})
+        self._idx = frozendict({key: i for i, key in enumerate(self._keys)})
 
     @staticmethod
     def make(inp):
@@ -42,11 +40,15 @@ class MultiDomain(object):
     def domains(self):
         return self._domains
 
+    @property
+    def idx(self):
+        return self._idx
+
     def items(self):
         return zip(self._keys, self._domains)
 
     def __getitem__(self, key):
-        return self._domains[self._dict[key]]
+        return self._domains[self._idx[key]]
 
     def __len__(self):
         return len(self._keys)
diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py
index d43e7f58c..bb6004c5f 100644
--- a/nifty5/multi/multi_field.py
+++ b/nifty5/multi/multi_field.py
@@ -55,21 +55,24 @@ class MultiField(object):
         self._val = val
 
     @staticmethod
-    def from_dict(dict):
-        domain = MultiDomain.make({key: v._domain for key, v in dict.items()})
-        return MultiField(domain, tuple(dict[key] for key in domain._keys))
+    def from_dict(dict, domain=None):
+        if domain is None:
+            domain = MultiDomain.make({key: v._domain
+                                       for key, v in dict.items()})
+        return MultiField(domain, tuple(dict[key] if key in dict else None
+                                        for key in domain.keys()))
 
     def to_dict(self):
-        return {key: val for key, val in zip(self._domain._keys, self._val)}
+        return {key: val for key, val in zip(self._domain.keys(), self._val)}
 
     def __getitem__(self, key):
-        return self._val[self._domain._dict[key]]
+        return self._val[self._domain.idx[key]]
 
     def keys(self):
         return self._domain.keys()
 
     def items(self):
-        return zip(self._domain._keys, self._val)
+        return zip(self._domain.keys(), self._val)
 
     def values(self):
         return self._val
diff --git a/nifty5/operators/selection_operator.py b/nifty5/operators/selection_operator.py
index daf31b92a..71fc3e860 100644
--- a/nifty5/operators/selection_operator.py
+++ b/nifty5/operators/selection_operator.py
@@ -53,6 +53,4 @@ class SelectionOperator(LinearOperator):
             return x[self._key]
         else:
             from ..multi.multi_field import MultiField
-            rval = [None]*len(self._domain)
-            rval[self._domain._dict[self._key]] = x
-            return MultiField(self._domain, tuple(rval))
+            return MultiField.from_dict({self._key: x}, self._domain)
diff --git a/nifty5/sugar.py b/nifty5/sugar.py
index 0dd149ac5..277cb3b62 100644
--- a/nifty5/sugar.py
+++ b/nifty5/sugar.py
@@ -228,10 +228,12 @@ def makeDomain(domain):
 
 
 def makeOp(input):
+    if input is None:
+        return None
     if isinstance(input, Field):
         return DiagonalOperator(input)
     if isinstance(input, MultiField):
-        return BlockDiagonalOperator({key: makeOp(val)
+        return BlockDiagonalOperator(input.domain, {key: makeOp(val)
                                       for key, val in input.items()})
     raise NotImplementedError
 
-- 
GitLab