diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py index 33459232050ff0bc77c514c21220e0487a037923..11f961f149b606aeae8e155181fa7db77cdd9aea 100644 --- a/nifty/operators/diagonal_operator.py +++ b/nifty/operators/diagonal_operator.py @@ -90,6 +90,16 @@ class DiagonalOperator(EndomorphicOperator): for i, j in enumerate(self._spaces): if diagonal.domain[i] != self._domain[j]: raise ValueError("domain mismatch") + if self._spaces == tuple(range(len(self._domain.domains))): + self._spaces = None # shortcut + + if self._spaces is not None: + active_axes = [] + for space_index in self._spaces: + active_axes += self._domain.axes[space_index] + + self._reshaper = [shp if i in active_axes else 1 + for i, shp in enumerate(self._domain.shape)] self._diagonal = diagonal.copy() self._self_adjoint = None @@ -140,13 +150,7 @@ class DiagonalOperator(EndomorphicOperator): if self._spaces is None: return operation(self._diagonal)(x) - active_axes = [] - for space_index in self._spaces: - active_axes += x.domain.axes[space_index] - - reshaper = [shp if i in active_axes else 1 - for i, shp in enumerate(x.shape)] - reshaped_local_diagonal = np.reshape(self._diagonal.val, reshaper) + reshaped_local_diagonal = np.reshape(self._diagonal.val, self._reshaper) # here the actual multiplication takes place return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))