Commit c2dd59ae authored by Neel Shah's avatar Neel Shah
Browse files

Fixed bug with adjoint_times for spaces=None, minor optimization and aesthetic changes

parent 4457dfd7
Pipeline #104575 canceled with stages
......@@ -4,6 +4,7 @@ from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from ..domains.rg_space import RGSpace
from .linear_operator import LinearOperator
class GeneralMatrixProduct(LinearOperator):
......@@ -19,7 +20,8 @@ class GeneralMatrixProduct(LinearOperator):
scipy.sparse matrices the `flatten` keyword argument must be
set to true. This means that the input field will be flattened
before applying the matrix and reshaped to its original shape
afterwards.
afterwards. Flattening is only supported when the domain and target
are the same, and a target can't be specified if flatten=True'
Matrices are tested regarding their compatibility with the
called for application method.
......@@ -80,13 +82,16 @@ class GeneralMatrixProduct(LinearOperator):
self._spaces = None
self._active_axes = utilities.my_sum(self._domain.axes)
self._inactive_axes = ()
if flatten:
domain_shape = (utilities.my_product(domain_shape), )
target_space_shape = domain_shape
target_dim = len(target_space_shape)
mat_inactive_axes_dim = mat_dim - len(domain_shape)
if mat_inactive_axes_dim < 0:
raise ValueError('Domain too big for matrix.')
raise ValueError("Domain has more dimensions than matrix.")
target_space_shape = matrix.shape[:mat_inactive_axes_dim]
target_dim = mat_inactive_axes_dim
if flatten:
target_space_shape = (utilities.my_product(target_space_shape), )
else:
if flatten:
raise ValueError(
......@@ -108,9 +113,9 @@ class GeneralMatrixProduct(LinearOperator):
domain_shape = tuple(domain_shape)
self._active_axes = tuple(active_axes)
self._inactive_axes = tuple(self._inactive_axes)
mat_inactive_axes_dim = len(matrix.shape) - len(domain_shape)
mat_inactive_axes_dim = mat_dim - len(domain_shape)
if mat_inactive_axes_dim < 0:
raise ValueError('Domain too big for matrix.')
raise ValueError("Domain has more dimensions than matrix.")
target_dim = mat_inactive_axes_dim + len(self._inactive_axes)
domain_dim = len(domain_shape)
......@@ -127,23 +132,44 @@ class GeneralMatrixProduct(LinearOperator):
self._mat_last_n = tuple([-domain_dim + i for i in range(domain_dim)])
self._mat_first_n = np.arange(domain_dim)
self._target_last_n = tuple([-len(self._inactive_axes) + i for i in range(len(self._inactive_axes))])
#mat_last_m is needed for adjoint application even if spaces = None
self._mat_last_m = tuple([-mat_inactive_axes_dim + i for i in range(mat_inactive_axes_dim)])
self._target_last_n = tuple([-len(self._inactive_axes) + i for
i in range(len(self._inactive_axes))])
# mat_last_m is needed for adjoint application even if spaces=None
self._mat_last_m = tuple([-mat_inactive_axes_dim + i for i in
range(mat_inactive_axes_dim)])
self._target_axes = tuple(range(len(target_space_shape)))
if spaces != None:
self._field_axes = list(self._target_axes)
for i in list(self._target_axes):
if i in self._inactive_axes:
self._field_axes.remove(i)
self._field_axes = tuple(self._field_axes)
if target == None:
if target_dim != 0:
default_target = DomainTuple.make(RGSpace(shape = target_space_shape))
if flatten:
self._target = self._domain
else:
default_target = DomainTuple.make(None)
self._target = default_target
if target_dim != 0:
default_target = DomainTuple.make(RGSpace(shape=target_space_shape))
else:
default_target = DomainTuple.make(None)
self._target = default_target
elif flatten:
raise ValueError("Flattening is supported only for endomorphic application,"
+ " and you can't specify a target.")
elif target.shape == target_space_shape:
self._target = target
self._target = DomainTuple.make(target)
else:
raise ValueError("Target space has invalid shape.")
raise ValueError(f"Target space has invalid shape.\n"
+"Its shape should be {target_space_shape}.")
if matrix.shape[mat_inactive_axes_dim:] != domain_shape:
raise ValueError("Matrix doesn't fit with the domain.")
matrix_appl_shape = matrix.shape[mat_inactive_axes_dim:]
if matrix_appl_shape != domain_shape:
raise ValueError("Matrix doesn't fit with the domain.\n" +
f"Shape of matrix axes used in summation: {matrix_appl_shape},\n " +
f"Shape of domain axes used in summation: {domain_shape}.")
self._mat = matrix
self._mat_tr = matrix.transpose().conjugate()
......@@ -158,29 +184,25 @@ class GeneralMatrixProduct(LinearOperator):
if self._spaces is None:
if not self._flatten:
if times:
res = np.tensordot(m,x.val,axes = len(x.shape))
res = np.tensordot(m, x.val, axes=len(x.domain.shape))
else:
mat_axes = np.flip(self._mat_last_m)
field_axes = list(range(len(self._target.shape)))
res = np.tensordot(m,x.val,axes=(mat_axes,field_axes))
res = res.reshape(np.flip(res.shape))
field_axes = self._target_axes
res = np.tensordot(m, x.val, axes=(mat_axes, field_axes))
res = res.transpose()
else:
res = m.dot(x.val.flatten()).reshape(self._domain.shape)
return Field(target,res)
return Field(target, res)
if times:
mat_axes = self._mat_last_n
move_axes = self._target_last_n
res = np.tensordot(m, x.val,axes=(mat_axes,self._active_axes))
res = np.moveaxis(res,move_axes,self._inactive_axes)
res = np.tensordot(m, x.val, axes=(mat_axes, self._active_axes))
res = np.moveaxis(res, move_axes, self._inactive_axes)
else:
mat_axes = np.flip(self._mat_last_m)
move_axes = np.flip(self._mat_first_n)
field_axes = list(range(len(self._target.shape)))
for i in range(len(field_axes)):
if i in self._inactive_axes:
field_axes.remove(i)
field_axes = tuple(field_axes)
res = np.tensordot(m, x.val, axes=(mat_axes,field_axes))
res = np.moveaxis(res, move_axes,self._active_axes)
return Field(target,res)
field_axes = self._field_axes
res = np.tensordot(m, x.val, axes=(mat_axes, field_axes))
res = np.moveaxis(res, move_axes, self._active_axes)
return Field(target, res)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment