Commit a5a73894 authored by Neel Shah's avatar Neel Shah
Browse files

replace flatten() by ravel(), and minor changes

parent 6b285039
Pipeline #106469 canceled with stages
......@@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator):
non-participating subspaces of the domain cannot be different in the
target space, thus the matrix must have enough unsummed axes to stand
in the places of summed-over axes of the domain, if those summed-over
axes are followed by any unsummed(inactive) axes. Example to make
axes are followed by any unsummed (inactive) axes. Example to make
this clear: If the first 2 spaces of the domain are summed over in
the matrix multiplication and the 3rd space doesn't participate, the
matrix must have (at least) 2 axes that don't participate in the
......@@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator):
domain_dim = len(domain_shape)
# take shortcut for trivial case
if spaces is not None:
if len(self._domain.shape) == 1 and spaces == (0, ):
spaces = None
if spaces is not None and len(self._domain.shape) == 1 and spaces == (0, ):
spaces = None
if spaces is None:
self._spaces = None
......@@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator):
matrix_shape_idx +=1
matrix_shape_idx += 1
target_space_shape = tuple(target_space_shape)
self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)])
......@@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator):
if self._spaces is None:
if not self._flatten:
if times:
res = np.tensordot(m, x.val, axes = len(x.domain.shape))
res = np.tensordot(m, x.val, axes = len(self._domain.shape))
mat_axes = np.flip(self._mat_last_m)
field_axes = self._target_axes
res = np.tensordot(m, x.val, axes=(mat_axes, field_axes))
res = res.transpose()
res =
res =
return Field(target, res)
if times:
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment