Commit 21f72f3c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge nifty2go

parents a2690f08 352cd0f7
Pipeline #21107 passed with stage
in 4 minutes and 11 seconds
......@@ -25,7 +25,6 @@ from ..nifty_utilities import cast_iseq_to_tuple
from ..dobj import to_ndarray as to_np
class DiagonalOperator(EndomorphicOperator):
""" NIFTY class for diagonal operators.
......@@ -93,6 +92,16 @@ class DiagonalOperator(EndomorphicOperator):
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
if self._spaces == tuple(range(len(self._domain.domains))):
self._spaces = None # shortcut
if self._spaces is not None:
active_axes = []
for space_index in self._spaces:
active_axes += self._domain.axes[space_index]
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(self._domain.shape)]
self._diagonal = diagonal.copy()
self._self_adjoint = None
......@@ -143,13 +152,5 @@ class DiagonalOperator(EndomorphicOperator):
if self._spaces is None:
return diag*x
active_axes = []
for space_index in self._spaces:
active_axes += x.domain.axes[space_index]
reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(x.shape)]
reshaped_local_diagonal = to_np(diag.val).reshape(reshaper)
# here the actual multiplication takes place
reshaped_local_diagonal = np.reshape(to_np(diag.val), self._reshaper)
return Field(x.domain, val=x.val*reshaped_local_diagonal)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment