replace flatten() by ravel(), and minor changes

parent 6b285039
Pipeline #106469 canceled with stages
 ... @@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator): ... @@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator): non-participating subspaces of the domain cannot be different in the non-participating subspaces of the domain cannot be different in the target space, thus the matrix must have enough unsummed axes to stand target space, thus the matrix must have enough unsummed axes to stand in the places of summed-over axes of the domain, if those summed-over in the places of summed-over axes of the domain, if those summed-over axes are followed by any unsummed(inactive) axes. Example to make axes are followed by any unsummed (inactive) axes. Example to make this clear: If the first 2 spaces of the domain are summed over in this clear: If the first 2 spaces of the domain are summed over in the matrix multiplication and the 3rd space doesn't participate, the the matrix multiplication and the 3rd space doesn't participate, the matrix must have (at least) 2 axes that don't participate in the matrix must have (at least) 2 axes that don't participate in the ... @@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator): ... @@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator): domain_dim = len(domain_shape) domain_dim = len(domain_shape) # take shortcut for trivial case # take shortcut for trivial case if spaces is not None: if spaces is not None and len(self._domain.shape) == 1 and spaces == (0, ): if len(self._domain.shape) == 1 and spaces == (0, ): spaces = None spaces = None if spaces is None: if spaces is None: self._spaces = None self._spaces = None ... @@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator): ... @@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator): target_space_shape.append(self._domain[i].shape[j]) target_space_shape.append(self._domain[i].shape[j]) else: else: target_space_shape.append(matrix.shape[matrix_shape_idx]) target_space_shape.append(matrix.shape[matrix_shape_idx]) matrix_shape_idx +=1 matrix_shape_idx += 1 target_space_shape = tuple(target_space_shape) target_space_shape = tuple(target_space_shape) self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)]) self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)]) ... @@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator): ... @@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator): if self._spaces is None: if self._spaces is None: if not self._flatten: if not self._flatten: if times: if times: res = np.tensordot(m, x.val, axes = len(x.domain.shape)) res = np.tensordot(m, x.val, axes = len(self._domain.shape)) else: else: mat_axes = np.flip(self._mat_last_m) mat_axes = np.flip(self._mat_last_m) field_axes = self._target_axes field_axes = self._target_axes res = np.tensordot(m, x.val, axes=(mat_axes, field_axes)) res = np.tensordot(m, x.val, axes=(mat_axes, field_axes)) res = res.transpose() res = res.transpose() else: else: res = m.dot(x.val.flatten()).reshape(self._domain.shape) res = m.dot(x.val.ravel()).reshape(self._domain.shape) return Field(target, res) return Field(target, res) if times: if times: ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!