From 352cd0f703bdcf22d6af57f1ed4719917fc381d3 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 6 Nov 2017 13:59:05 +0100
Subject: [PATCH] tweak

---
 nifty/operators/diagonal_operator.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py
index 334592320..11f961f14 100644
--- a/nifty/operators/diagonal_operator.py
+++ b/nifty/operators/diagonal_operator.py
@@ -90,6 +90,16 @@ class DiagonalOperator(EndomorphicOperator):
             for i, j in enumerate(self._spaces):
                 if diagonal.domain[i] != self._domain[j]:
                     raise ValueError("domain mismatch")
+            if self._spaces == tuple(range(len(self._domain.domains))):
+                self._spaces = None  # shortcut
+
+        if self._spaces is not None:
+            active_axes = []
+            for space_index in self._spaces:
+                active_axes += self._domain.axes[space_index]
+
+            self._reshaper = [shp if i in active_axes else 1
+                              for i, shp in enumerate(self._domain.shape)]
 
         self._diagonal = diagonal.copy()
         self._self_adjoint = None
@@ -140,13 +150,7 @@ class DiagonalOperator(EndomorphicOperator):
         if self._spaces is None:
             return operation(self._diagonal)(x)
 
-        active_axes = []
-        for space_index in self._spaces:
-            active_axes += x.domain.axes[space_index]
-
-        reshaper = [shp if i in active_axes else 1
-                    for i, shp in enumerate(x.shape)]
-        reshaped_local_diagonal = np.reshape(self._diagonal.val, reshaper)
+        reshaped_local_diagonal = np.reshape(self._diagonal.val, self._reshaper)
 
         # here the actual multiplication takes place
         return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))
-- 
GitLab