Commit 5a7c3863 authored by Neel Shah's avatar Neel Shah
Browse files

implement compatibility with sparse matrices

parent c88b1575
Pipeline #107175 canceled with stages
......@@ -80,8 +80,8 @@ class TensorDotOperator(LinearOperator):
tensor: scipy.sparse matrix or numpy array
Tensor of a shape whose last few axes should match the shape
of the axes of the field to be summed over (if not 'flatten').
If `flatten`, needs to be applicable to the val array of
input fields by `tensor.dot()`.
If it is not a numpy array, needs to be applicable to the val
array of input fields by `tensor.dot()`.
spaces: int or tuple of int, optional
The subdomain(s) of "domain" which the operator acts on.
If None, it acts on all elements.
......@@ -216,21 +216,24 @@ class TensorDotOperator(LinearOperator):
if self._spaces is None:
if not self._flatten:
if times:
res = np.tensordot(t, x.val, axes = len(self._domain.shape))
if type(t) == np.ndarray:
if times:
res = np.tensordot(t, x.val, axes = len(self._domain.shape))
else:
tensor_axes = np.flip(self._tensor_last_m)
field_axes = self._target_axes
res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
res = res.transpose()
else:
mat_axes = np.flip(self._tensor_last_m)
field_axes = self._target_axes
res = np.tensordot(t, x.val, axes=(mat_axes, field_axes))
res = res.transpose()
res = t.dot(x.val)
else:
res = t.dot(x.val.ravel()).reshape(self._domain.shape)
return Field(target, res)
if times:
mat_axes = self._tensor_last_n
tensor_axes = self._tensor_last_n
move_axes = self._target_last_n
res = np.tensordot(t, x.val, axes=(mat_axes, self._active_axes))
res = np.tensordot(t, x.val, axes=(tensor_axes, self._active_axes))
try:
res = np.moveaxis(res, move_axes, self._inactive_axes)
except np.AxisError:
......@@ -238,11 +241,11 @@ class TensorDotOperator(LinearOperator):
"Number of dimensions in tensor:" +
f"{len(t.shape)}\n")
else:
mat_axes = np.flip(self._tensor_last_m)
tensor_axes = np.flip(self._tensor_last_m)
move_axes = np.flip(self._tensor_first_n)
field_axes = self._field_axes
try:
res = np.tensordot(t, x.val, axes=(mat_axes, field_axes))
res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
except ValueError as e:
if e.args[0] == "shape-mismatch for sum":
raise ValueError("The tensor has too few extra dimensions.\n" +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment