From 352cd0f703bdcf22d6af57f1ed4719917fc381d3 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 6 Nov 2017 13:59:05 +0100 Subject: [PATCH] tweak --- nifty/operators/diagonal_operator.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py index 334592320..11f961f14 100644 --- a/nifty/operators/diagonal_operator.py +++ b/nifty/operators/diagonal_operator.py @@ -90,6 +90,16 @@ class DiagonalOperator(EndomorphicOperator): for i, j in enumerate(self._spaces): if diagonal.domain[i] != self._domain[j]: raise ValueError("domain mismatch") + if self._spaces == tuple(range(len(self._domain.domains))): + self._spaces = None # shortcut + + if self._spaces is not None: + active_axes = [] + for space_index in self._spaces: + active_axes += self._domain.axes[space_index] + + self._reshaper = [shp if i in active_axes else 1 + for i, shp in enumerate(self._domain.shape)] self._diagonal = diagonal.copy() self._self_adjoint = None @@ -140,13 +150,7 @@ class DiagonalOperator(EndomorphicOperator): if self._spaces is None: return operation(self._diagonal)(x) - active_axes = [] - for space_index in self._spaces: - active_axes += x.domain.axes[space_index] - - reshaper = [shp if i in active_axes else 1 - for i, shp in enumerate(x.shape)] - reshaped_local_diagonal = np.reshape(self._diagonal.val, reshaper) + reshaped_local_diagonal = np.reshape(self._diagonal.val, self._reshaper) # here the actual multiplication takes place return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val)) -- GitLab