Commit 352cd0f7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak

parent 2d59db54
Pipeline #21106 passed with stage
in 4 minutes and 9 seconds
......@@ -90,6 +90,16 @@ class DiagonalOperator(EndomorphicOperator):
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
if self._spaces == tuple(range(len(self._domain.domains))):
self._spaces = None # shortcut
if self._spaces is not None:
active_axes = []
for space_index in self._spaces:
active_axes += self._domain.axes[space_index]
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(self._domain.shape)]
self._diagonal = diagonal.copy()
self._self_adjoint = None
......@@ -140,13 +150,7 @@ class DiagonalOperator(EndomorphicOperator):
if self._spaces is None:
return operation(self._diagonal)(x)
active_axes = []
for space_index in self._spaces:
active_axes += x.domain.axes[space_index]
reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(x.shape)]
reshaped_local_diagonal = np.reshape(self._diagonal.val, reshaper)
reshaped_local_diagonal = np.reshape(self._diagonal.val, self._reshaper)
# here the actual multiplication takes place
return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment