Commit 9e4b327a authored by Lukas Platz's avatar Lukas Platz
Browse files

use spaces instead of axes, add flatten option

parent c56d29a4
Pipeline #71847 passed with stages
in 15 minutes and 11 seconds
......@@ -16,12 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from .. import utilities
import numpy as np
......@@ -353,49 +354,113 @@ class PartialExtractor(LinearOperator):
class MatrixProductOperator(EndomorphicOperator):
"""Endomorphic matrix multiplication with input field.
This operator supports scipy.sparse matrices and numpy arrays
as the matrix to be applied.
For numpy array matrices, can apply the matrix over a subspace
of the input.
If the input arrays have more than one dimension, for
scipy.sparse matrices the `flatten` keyword argument must be
set to true. This means that the input field will be flattened
before applying the matrix and reshaped to its original shape
Matrices are tested regarding their compatibility with the
called for application method.
Flattening and subspace application are mutually exclusive.
domain: :class:`Domain` or :class:`DomainTuple`
Domain of the operator.
If :class:`DomainTuple` it is assumed to have only one entry.
matrix: scipy.sparse matrix or numpy array
Matrix of shape `(domain.shape, domain.shape)`. Needs to support
`dot()` and `transpose()` in the style of numpy arrays.
axis: integer or None
in case of multi-dim input fields (N > 1), along which axis
of the input field to apply the matrix
Quadratic matrix of shape `(domain.shape, domain.shape)`
(if `not flatten`) that supports `matrix.transpose()`.
If it is not a numpy array, needs to be applicable to the val
array of input fields by ``.
spaces: int or tuple of int, optional
The subdomain(s) of "domain" which the operator acts on.
If None, it acts on all elements.
Only possible for numpy array matrices.
If `len(domain) > 1` and `flatten=False`, this parameter is
flatten: boolean, optional
Whether the input value array should be flattened before
applying the matrix and reshaped to its original shape
Needed for scipy.sparse matrices if `len(domain) > 1`.
def __init__(self, domain, matrix, axis=None):
def __init__(self, domain, matrix, spaces=None, flatten=False):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
if axis is None:
raise ValueError(
"For multi-dim inputs an axis needs to be specified.")
ref_shp = (shp[axis], shp[axis])
mat_dim = len(matrix.shape)
if mat_dim % 2 != 0 or \
matrix.shape != (matrix.shape[:mat_dim//2] + matrix.shape[:mat_dim//2]):
raise ValueError("Matrix must be quadratic.")
appl_dim = mat_dim // 2 # matrix application space dimension
# take shortcut for trivial case
if spaces is not None:
if len(self._domain.shape) == 1 and spaces == (0, ):
spaces = None
if spaces is None:
self._spaces = None
self._active_axes = utilities.my_sum(self._domain.axes)
appl_space_shape = self._domain.shape
if flatten:
appl_space_shape = (utilities.my_product(appl_space_shape), )
if not (axis is None or axis == 0):
if flatten:
raise ValueError(
"For one-dim inputs axis must be None or zero")
ref_shp = (shp[0], shp[0])
axis = None
if matrix.shape != ref_shp:
"Cannot flatten input AND apply to a subspace")
if not isinstance(matrix, np.ndarray):
raise ValueError(
"Application to subspaces only supported for numpy array matrices."
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
appl_space_shape = []
active_axes = []
for space_idx in spaces:
appl_space_shape += self._domain[space_idx].shape
active_axes += self._domain.axes[space_idx]
appl_space_shape = tuple(appl_space_shape)
self._active_axes = tuple(active_axes)
self._mat_last_n = tuple([-appl_dim + i for i in range(appl_dim)])
self._mat_first_n = np.arange(appl_dim)
# Test if the matrix and the array it will be applied to fit
if matrix.shape[:appl_dim] != appl_space_shape:
raise ValueError(
"Domain/domain on axis and matrix shape do not match.")
"Matrix and domain shapes are incompatible under the requested "
+ "application scheme.\n" +
f"Matrix appl shape: {matrix.shape[:appl_dim]}, " +
f"appl_space_shape: {appl_space_shape}.")
self._mat = matrix
self._mat_tr = matrix.transpose().conjugate()
self._axis = axis
self._flatten = flatten
def apply(self, x, mode):
self._check_input(x, mode)
m = self._mat if mode == self.TIMES else self._mat_tr
if self._axis is None:
res =
res = np.tensordot(m, x.val, axes=(-1, self._axis))
res = np.moveaxis(res, 0, self._axis)
times = (mode == self.TIMES)
m = self._mat if times else self._mat_tr
if self._spaces is None:
if not self._flatten:
res =
res =
return Field(self._domain, res)
mat_axes = self._mat_last_n if times else np.flip(self._mat_last_n)
move_axes = self._mat_first_n if times else np.flip(self._mat_first_n)
res = np.tensordot(m, x.val, axes=(mat_axes, self._active_axes))
res = np.moveaxis(res, move_axes, self._active_axes)
return Field(self._domain, res)
......@@ -277,34 +277,40 @@ def testSpecialSum(sp):
op = ift.library.correlated_fields._SpecialSum(sp)
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator_1d(sp, seed):
def metatestMatrixProductOperator(sp, mat_shape, seed, **kwargs):
mat = ift.random.current_rng().standard_normal((*sp.shape, *sp.shape))
op = ift.MatrixProductOperator(sp, mat)
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(sp, mat, **kwargs)
mat = mat + 1j*ift.random.current_rng().standard_normal((*sp.shape, *sp.shape))
op = ift.MatrixProductOperator(sp, mat)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(sp, mat, **kwargs)
@pmp('sp', [ift.RGSpace(10)])
@pmp('spaces', [None, (0,)])
@pmp('seed', [12, 3])
def testMatrixProductOperator_1d(sp, spaces, seed):
mat_shape = sp.shape * 2
metatestMatrixProductOperator(sp, mat_shape, seed, spaces=spaces)
@pmp('sp', [ift.RGSpace((2, 10))])
@pmp('axis', [0, 1])
@pmp('sp', [ift.DomainTuple.make((ift.RGSpace((2)), ift.RGSpace((10))))])
@pmp('spaces', [(0,), (1,), (0, 1)])
@pmp('seed', [12, 3])
def testMatrixProductOperator_2d(sp, axis, seed):
mat_shp = (sp.shape[axis], sp.shape[axis])
mat = ift.random.current_rng().standard_normal(mat_shp)
op = ift.MatrixProductOperator(sp, mat, axis)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shp)
op = ift.MatrixProductOperator(sp, mat, axis)
def testMatrixProductOperator_2d_spaces(sp, spaces, seed):
appl_shape = []
for sp_idx in spaces:
appl_shape += sp[sp_idx].shape
appl_shape = tuple(appl_shape)
mat_shape = appl_shape * 2
metatestMatrixProductOperator(sp, mat_shape, seed, spaces=spaces)
@pmp('sp', [ift.RGSpace((2, 10))])
@pmp('seed', [12, 3])
def testMatrixProductOperator_2d_flatten(sp, seed):
appl_shape = (ift.utilities.my_product(sp.shape),)
mat_shape = appl_shape * 2
metatestMatrixProductOperator(sp, mat_shape, seed, flatten=True)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
......@@ -318,7 +324,6 @@ def testPartialExtractor(seed):
@pmp('seed', [12, 3])
def testSlowFieldAdapter(seed):
dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment