From 68b71aca97e3e61b6e98687f8349f3baa7f15cd9 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 29 Oct 2017 11:41:30 +0100
Subject: [PATCH] locate all places where adjustments are needed for
 distributed fields

---
 nifty/data_objects/my_own_do.py      |  4 ++--
 nifty/dobj.py                        |  1 +
 nifty/operators/diagonal_operator.py | 17 +++++++++--------
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py
index 594ae2838..56c2d31dd 100644
--- a/nifty/data_objects/my_own_do.py
+++ b/nifty/data_objects/my_own_do.py
@@ -56,8 +56,8 @@ class data_object(object):
         a = self._data
         if isinstance(other, data_object):
             b = other._data
-            # if a.shape != b.shape:
-            #     print("shapes are incompatible.")
+            if a.shape != b.shape:
+                raise ValueError("shapes are incompatible.")
         else:
             b = other
 
diff --git a/nifty/dobj.py b/nifty/dobj.py
index 2264d4915..cb3af9f93 100644
--- a/nifty/dobj.py
+++ b/nifty/dobj.py
@@ -1 +1,2 @@
 from .data_objects.my_own_do import *
+#from .data_objects.numpy_do import *
diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py
index 1757c1271..5006540e1 100644
--- a/nifty/operators/diagonal_operator.py
+++ b/nifty/operators/diagonal_operator.py
@@ -22,6 +22,7 @@ from ..field import Field
 from ..domain_tuple import DomainTuple
 from .endomorphic_operator import EndomorphicOperator
 from ..nifty_utilities import cast_iseq_to_tuple
+from ..dobj import to_ndarray as to_np
 
 
 class DiagonalOperator(EndomorphicOperator):
@@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator):
         self._unitary = None
 
     def _times(self, x):
-        return self._times_helper(x, lambda z: z.__mul__)
+        return self._times_helper(x, self._diagonal)
 
     def _adjoint_times(self, x):
-        return self._times_helper(x, lambda z: z.conjugate().__mul__)
+        return self._times_helper(x, self._diagonal.conj())
 
     def _inverse_times(self, x):
-        return self._times_helper(x, lambda z: z.__rtruediv__)
+        return self._times_helper(x, 1./self._diagonal)
 
     def _adjoint_inverse_times(self, x):
-        return self._times_helper(x, lambda z: z.conjugate().__rtruediv__)
+        return self._times_helper(x, 1./self._diagonal.conj())
 
     def diagonal(self):
         """ Returns the diagonal of the Operator.
@@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator):
             self._unitary = (abs(self._diagonal.val) == 1.).all()
         return self._unitary
 
-    def _times_helper(self, x, operation):
+    def _times_helper(self, x, diag):
         if self._spaces is None:
-            return operation(self._diagonal)(x)
+            return diag*x
 
         active_axes = []
         for space_index in self._spaces:
@@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator):
 
         reshaper = [shp if i in active_axes else 1
                     for i, shp in enumerate(x.shape)]
-        reshaped_local_diagonal = self._diagonal.val.reshape(reshaper)
+        reshaped_local_diagonal = to_np(diag.val).reshape(reshaper)
 
         # here the actual multiplication takes place
-        return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))
+        return Field(x.domain, val=x.val*reshaped_local_diagonal)
-- 
GitLab