Commit 5a7c3863 authored by Neel Shah's avatar Neel Shah
Browse files

implement compatibility with sparse matrices

parent c88b1575
Pipeline #107175 canceled with stages
...@@ -80,8 +80,8 @@ class TensorDotOperator(LinearOperator): ...@@ -80,8 +80,8 @@ class TensorDotOperator(LinearOperator):
tensor: scipy.sparse matrix or numpy array tensor: scipy.sparse matrix or numpy array
Tensor of a shape whose last few axes should match the shape Tensor of a shape whose last few axes should match the shape
of the axes of the field to be summed over (if not 'flatten'). of the axes of the field to be summed over (if not 'flatten').
If `flatten`, needs to be applicable to the val array of If it is not a numpy array, needs to be applicable to the val
input fields by `tensor.dot()`. array of input fields by `tensor.dot()`.
spaces: int or tuple of int, optional spaces: int or tuple of int, optional
The subdomain(s) of "domain" which the operator acts on. The subdomain(s) of "domain" which the operator acts on.
If None, it acts on all elements. If None, it acts on all elements.
...@@ -216,21 +216,24 @@ class TensorDotOperator(LinearOperator): ...@@ -216,21 +216,24 @@ class TensorDotOperator(LinearOperator):
if self._spaces is None: if self._spaces is None:
if not self._flatten: if not self._flatten:
if times: if type(t) == np.ndarray:
res = np.tensordot(t, x.val, axes = len(self._domain.shape)) if times:
res = np.tensordot(t, x.val, axes = len(self._domain.shape))
else:
tensor_axes = np.flip(self._tensor_last_m)
field_axes = self._target_axes
res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
res = res.transpose()
else: else:
mat_axes = np.flip(self._tensor_last_m) res = t.dot(x.val)
field_axes = self._target_axes
res = np.tensordot(t, x.val, axes=(mat_axes, field_axes))
res = res.transpose()
else: else:
res = t.dot(x.val.ravel()).reshape(self._domain.shape) res = t.dot(x.val.ravel()).reshape(self._domain.shape)
return Field(target, res) return Field(target, res)
if times: if times:
mat_axes = self._tensor_last_n tensor_axes = self._tensor_last_n
move_axes = self._target_last_n move_axes = self._target_last_n
res = np.tensordot(t, x.val, axes=(mat_axes, self._active_axes)) res = np.tensordot(t, x.val, axes=(tensor_axes, self._active_axes))
try: try:
res = np.moveaxis(res, move_axes, self._inactive_axes) res = np.moveaxis(res, move_axes, self._inactive_axes)
except np.AxisError: except np.AxisError:
...@@ -238,11 +241,11 @@ class TensorDotOperator(LinearOperator): ...@@ -238,11 +241,11 @@ class TensorDotOperator(LinearOperator):
"Number of dimensions in tensor:" + "Number of dimensions in tensor:" +
f"{len(t.shape)}\n") f"{len(t.shape)}\n")
else: else:
mat_axes = np.flip(self._tensor_last_m) tensor_axes = np.flip(self._tensor_last_m)
move_axes = np.flip(self._tensor_first_n) move_axes = np.flip(self._tensor_first_n)
field_axes = self._field_axes field_axes = self._field_axes
try: try:
res = np.tensordot(t, x.val, axes=(mat_axes, field_axes)) res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
except ValueError as e: except ValueError as e:
if e.args[0] == "shape-mismatch for sum": if e.args[0] == "shape-mismatch for sum":
raise ValueError("The tensor has too few extra dimensions.\n" + raise ValueError("The tensor has too few extra dimensions.\n" +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment