Commit a5a73894 authored by Neel Shah's avatar Neel Shah
Browse files

replace flatten() by ravel(), and minor changes

parent 6b285039
Pipeline #106469 canceled with stages
...@@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator): ...@@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator):
non-participating subspaces of the domain cannot be different in the non-participating subspaces of the domain cannot be different in the
target space, thus the matrix must have enough unsummed axes to stand target space, thus the matrix must have enough unsummed axes to stand
in the places of summed-over axes of the domain, if those summed-over in the places of summed-over axes of the domain, if those summed-over
axes are followed by any unsummed(inactive) axes. Example to make axes are followed by any unsummed (inactive) axes. Example to make
this clear: If the first 2 spaces of the domain are summed over in this clear: If the first 2 spaces of the domain are summed over in
the matrix multiplication and the 3rd space doesn't participate, the the matrix multiplication and the 3rd space doesn't participate, the
matrix must have (at least) 2 axes that don't participate in the matrix must have (at least) 2 axes that don't participate in the
...@@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator): ...@@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator):
domain_dim = len(domain_shape) domain_dim = len(domain_shape)
# take shortcut for trivial case # take shortcut for trivial case
if spaces is not None: if spaces is not None and len(self._domain.shape) == 1 and spaces == (0, ):
if len(self._domain.shape) == 1 and spaces == (0, ): spaces = None
spaces = None
if spaces is None: if spaces is None:
self._spaces = None self._spaces = None
...@@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator): ...@@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator):
target_space_shape.append(self._domain[i].shape[j]) target_space_shape.append(self._domain[i].shape[j])
else: else:
target_space_shape.append(matrix.shape[matrix_shape_idx]) target_space_shape.append(matrix.shape[matrix_shape_idx])
matrix_shape_idx +=1 matrix_shape_idx += 1
target_space_shape = tuple(target_space_shape) target_space_shape = tuple(target_space_shape)
self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)]) self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)])
...@@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator): ...@@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator):
if self._spaces is None: if self._spaces is None:
if not self._flatten: if not self._flatten:
if times: if times:
res = np.tensordot(m, x.val, axes = len(x.domain.shape)) res = np.tensordot(m, x.val, axes = len(self._domain.shape))
else: else:
mat_axes = np.flip(self._mat_last_m) mat_axes = np.flip(self._mat_last_m)
field_axes = self._target_axes field_axes = self._target_axes
res = np.tensordot(m, x.val, axes=(mat_axes, field_axes)) res = np.tensordot(m, x.val, axes=(mat_axes, field_axes))
res = res.transpose() res = res.transpose()
else: else:
res = m.dot(x.val.flatten()).reshape(self._domain.shape) res = m.dot(x.val.ravel()).reshape(self._domain.shape)
return Field(target, res) return Field(target, res)
if times: if times:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment