From 6e6b4f76e717501fa408724a677b6eeab1a7addb Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Thu, 26 Jul 2018 11:40:05 +0200
Subject: [PATCH] running, but not correctly

---
 nifty5/field.py                          |  2 ++
 nifty5/operator.py                       | 22 ++++++++++++++++++----
 nifty5/operators/relaxed_sum_operator.py |  4 ++--
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/nifty5/field.py b/nifty5/field.py
index 377e74e7a..1516458f1 100644
--- a/nifty5/field.py
+++ b/nifty5/field.py
@@ -49,6 +49,8 @@ class Field(object):
     def __init__(self, domain, val):
         if not isinstance(domain, DomainTuple):
             raise TypeError("domain must be of type DomainTuple")
+        if np.isscalar(val):
+            val = np.full((), val)
         if not isinstance(val, dobj.data_object):
             raise TypeError("val must be of type dobj.data_object")
         if domain.shape != val.shape:
diff --git a/nifty5/operator.py b/nifty5/operator.py
index d47d6df1c..c3ff955c3 100644
--- a/nifty5/operator.py
+++ b/nifty5/operator.py
@@ -11,9 +11,10 @@ from .multi.multi_field import MultiField
 
 
 class Linearization(object):
-    def __init__(self, val, jac):
+    def __init__(self, val, jac, metric=None):
         self._val = val
         self._jac = jac
+        self._metric = metric
 
     @property
     def domain(self):
@@ -31,17 +32,25 @@ class Linearization(object):
     def jac(self):
         return self._jac
 
+    @property
+    def metric(self):
+        return self._metric
+
     def __neg__(self):
-        return Linearization(-self._val, self._jac*(-1))
+        return Linearization(-self._val, self._jac*(-1),
+                             None if self._metric is None else self._metric*(-1))
 
     def __add__(self, other):
         if isinstance(other, Linearization):
             from .operators.relaxed_sum_operator import RelaxedSumOperator
+            met = None
+            if self._metric is not None and other._metric is not None:
+                met = RelaxedSumOperator((self._metric, other._metric))
             return Linearization(
                 self._val.unite(other._val),
-                RelaxedSumOperator((self._jac, other._jac)))
+                RelaxedSumOperator((self._jac, other._jac)), met)
         if isinstance(other, (int, float, complex, Field, MultiField)):
-            return Linearization(self._val+other, self._jac)
+            return Linearization(self._val+other, self._jac, self._metric)
 
     def __radd__(self, other):
         return self.__add__(other)
@@ -76,6 +85,11 @@ class Linearization(object):
             d1 = makeOp(other)
             return Linearization(self._val*other, d1*self._jac)
 
+    def sum(self):
+        from .sugar import full
+        from .operators.vdot_operator import VdotOperator
+        return Linearization(full((),self._val.sum()),
+                             VdotOperator(full(self._jac.target,1))*self._jac)
     @staticmethod
     def make_var(field):
         from .operators.scaling_operator import ScalingOperator
diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py
index b44bb77b2..aeac60ae7 100644
--- a/nifty5/operators/relaxed_sum_operator.py
+++ b/nifty5/operators/relaxed_sum_operator.py
@@ -60,6 +60,6 @@ class RelaxedSumOperator(LinearOperator):
         self._check_mode(mode)
         res = None
         for op in self._ops:
-            x = op.apply(x.extract(op._dom(mode)), mode)
-            res = x if res is None else res.unite(x)
+            tmp = op.apply(x.extract(op._dom(mode)), mode)
+            res = tmp if res is None else res.unite(tmp)
         return res
-- 
GitLab