Commit b20e3981 authored by Neel Shah's avatar Neel Shah
Browse files

Generalizing the endomorphic MatrixProductOperator

parent 859032e1
Pipeline #104267 failed with stages
import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from import RGSpace
class GeneralMatrixProduct(LinearOperator):
"""Matrix multiplication with input field.
This operator supports scipy.sparse matrices and numpy arrays
as the matrix to be applied.
For numpy array matrices, can apply the matrix over a subspace
of the input.
If the input arrays have more than one dimension, for
scipy.sparse matrices the `flatten` keyword argument must be
set to true. This means that the input field will be flattened
before applying the matrix and reshaped to its original shape
Matrices are tested regarding their compatibility with the
called for application method.
Flattening and subspace application are mutually exclusive.
The target space type and distances can be specified. If
unspecified, it defaults to an RGSpace with default distance
convention. Either a single domain or a DomainTuple of the valid
shape can be set as the target.
When applied to specific subspaces of the domain, the
non-participating subspaces of the domain retain their positions
in the target space. The order of other axes in the target space is
the matrix's axes in their original order. This convention matches
with the endomorphic MatrixProductOperator.
domain: :class:`Domain` or :class:`DomainTuple`
Domain of the operator.
If :class:`Domain` it is assumed to have only one subspace.
matrix: scipy.sparse matrix or numpy array
Quadratic matrix of shape `(domain.shape, domain.shape)`
(if `not flatten`) that supports `matrix.transpose()`.
If it is not a numpy array, needs to be applicable to the val
array of input fields by ``.
spaces: int or tuple of int, optional
The subdomain(s) of "domain" which the operator acts on.
If None, it acts on all elements.
Only possible for numpy array matrices.
If `len(domain) > 1` and `flatten=False`, this parameter is
flatten: boolean, optional
Whether the input value array should be flattened before
applying the matrix and reshaped to its original shape
Needed for scipy.sparse matrices if `len(domain) > 1`.
target: :class:`Domain` or :class:`DomainTuple`, optional
Target of the operator. It must be of the valid shape, other
parameters like domain type and distances are flexible. The
default is an RGSpace with default distances convention.
def __init__(self, domain, matrix, spaces=None, flatten=False, target=None):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = DomainTuple.make(domain)
mat_dim = len(matrix.shape)
domain_shape = domain.shape
domain_dim = len(domain_shape)
# take shortcut for trivial case
if spaces is not None:
if len(self._domain.shape) == 1 and spaces == (0, ):
spaces = None
if spaces is None:
self._spaces = None
self._active_axes = utilities.my_sum(self._domain.axes)
self._inactive_axes = ()
mat_inactive_axes_dim = mat_dim - len(domain_shape)
if mat_inactive_axes_dim < 0:
raise ValueError('Domain too big for matrix.')
target_space_shape = matrix.shape[:mat_inactive_axes_dim]
target_dim = mat_inactive_axes_dim
if flatten:
target_space_shape = (utilities.my_product(target_space_shape), )
if flatten:
raise ValueError(
"Cannot flatten input AND apply to a subspace")
if not isinstance(matrix, np.ndarray):
raise ValueError(
"Application to subspaces only supported for numpy array matrices."
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
active_axes = []
self._inactive_axes = list(range(len(self._domain)))
domain_shape = []
for space_idx in spaces:
domain_shape += self._domain[space_idx].shape
active_axes += self._domain.axes[space_idx]
if space_idx in self._inactive_axes:
domain_shape = tuple(domain_shape)
self._active_axes = tuple(active_axes)
self._inactive_axes = tuple(self._inactive_axes)
mat_inactive_axes_dim = len(matrix.shape) - len(domain_shape)
if mat_inactive_axes_dim < 0:
raise ValueError('Domain too big for matrix.')
target_dim = mat_inactive_axes_dim + len(self._inactive_axes)
domain_dim = len(domain_shape)
target_space_shape = []
matrix_shape_idx = 0
for i in range(target_dim):
if i in tuple(self._inactive_axes):
for j in range(len(self._domain[i].shape)):
matrix_shape_idx +=1
target_space_shape = tuple(target_space_shape)
self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)])
self._mat_first_n = np.arange(domain_dim)
self._target_last_n = tuple([-len(self._inactive_axes) + i for i in range(len(self._inactive_axes))])
#mat_last_m is needed for adjoint application even if spaces = None
self._mat_last_m = tuple([-mat_inactive_axes_dim + i for i in range(mat_inactive_axes_dim)])
if target == None:
if target_dim != 0:
default_target = DomainTuple.make(RGSpace(shape = target_space_shape))
default_target = DomainTuple.make(None)
self._target = default_target
elif target.shape == target_space_shape:
self._target = target
raise ValueError("Target space has invalid shape.")
if matrix.shape[mat_inactive_axes_dim:] != domain_shape:
raise ValueError("Matrix doesn't fit with the domain.")
self._mat = matrix
self._mat_tr = matrix.transpose().conjugate()
self._flatten = flatten
def apply(self, x, mode):
self._check_input(x, mode)
times = (mode == self.TIMES)
m = self._mat if times else self._mat_tr
target = self._target if times else self._domain
if self._spaces is None:
if not self._flatten:
if times:
res = np.tensordot(m,x.val,axes = len(x.shape))
mat_axes = np.flip(self._mat_last_m)
field_axes = list(range(len(self._target.shape)))
res = np.tensordot(m,x.val,axes=(mat_axes,field_axes))
res = res.reshape(np.flip(res.shape))
res =
return Field(target,res)
if times:
mat_axes = self._mat_last_n
move_axes = self._target_last_n
res = np.tensordot(m, x.val,axes=(mat_axes,self._active_axes))
res = np.moveaxis(res,move_axes,self._inactive_axes)
mat_axes = np.flip(self._mat_last_m)
move_axes = np.flip(self._mat_first_n)
field_axes = list(range(len(self._target.shape)))
for i in range(len(field_axes)):
if i in self._inactive_axes:
field_axes = tuple(field_axes)
res = np.tensordot(m, x.val, axes=(mat_axes,field_axes))
res = np.moveaxis(res, move_axes,self._active_axes)
return Field(target,res)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment