Skip to content
Snippets Groups Projects
Commit c56d29a4 authored by Lukas Platz's avatar Lukas Platz
Browse files

extend MatrixProductOperator for multi-dim fields

parent 193a276f
No related branches found
No related tags found
1 merge request!433Extend MatrixProductOperator to be able to operate on subdomains of fields
Pipeline #71502 passed
......@@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
import numpy as np
class VdotOperator(LinearOperator):
......@@ -360,21 +361,41 @@ class MatrixProductOperator(EndomorphicOperator):
matrix: scipy.sparse matrix or numpy array
Matrix of shape `(domain.shape, domain.shape)`. Needs to support
`dot()` and `transpose()` in the style of numpy arrays.
axis: integer or None
in case of multi-dim input fields (N > 1), along which axis
of the input field to apply the matrix
"""
def __init__(self, domain, matrix):
def __init__(self, domain, matrix, axis=None):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
raise TypeError('Only 1D-domain supported yet.')
if matrix.shape != (*shp, *shp):
raise ValueError
self._capability = self.TIMES | self.ADJOINT_TIMES
if axis is None:
raise ValueError(
"For multi-dim inputs an axis needs to be specified.")
ref_shp = (shp[axis], shp[axis])
else:
if not (axis is None or axis == 0):
raise ValueError(
"For one-dim inputs axis must be None or zero")
ref_shp = (shp[0], shp[0])
axis = None
if matrix.shape != ref_shp:
raise ValueError(
"Domain/domain on axis and matrix shape do not match.")
self._mat = matrix
self._mat_tr = matrix.transpose().conjugate()
self._axis = axis
def apply(self, x, mode):
self._check_input(x, mode)
res = x.val
f = self._mat.dot if mode == self.TIMES else self._mat_tr.dot
res = f(res)
m = self._mat if mode == self.TIMES else self._mat_tr
if self._axis is None:
res = m.dot(x.val)
else:
res = np.tensordot(m, x.val, axes=(-1, self._axis))
res = np.moveaxis(res, 0, self._axis)
return Field(self._domain, res)
......@@ -280,7 +280,7 @@ def testSpecialSum(sp):
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator(sp, seed):
def testMatrixProductOperator_1d(sp, seed):
ift.random.push_sseq_from_seed(seed)
mat = ift.random.current_rng().standard_normal((*sp.shape, *sp.shape))
op = ift.MatrixProductOperator(sp, mat)
......@@ -291,6 +291,21 @@ def testMatrixProductOperator(sp, seed):
ift.random.pop_sseq()
@pmp('sp', [ift.RGSpace((2, 10))])
@pmp('axis', [0, 1])
@pmp('seed', [12, 3])
def testMatrixProductOperator_2d(sp, axis, seed):
mat_shp = (sp.shape[axis], sp.shape[axis])
ift.random.push_sseq_from_seed(seed)
mat = ift.random.current_rng().standard_normal(mat_shp)
op = ift.MatrixProductOperator(sp, mat, axis)
ift.extra.consistency_check(op)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shp)
op = ift.MatrixProductOperator(sp, mat, axis)
ift.extra.consistency_check(op)
ift.random.pop_sseq()
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
ift.random.push_sseq_from_seed(seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment