Skip to content
Snippets Groups Projects
Commit c88eadc1 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix MatrixProductOperator

parent 93b3fdd9
Branches
Tags
1 merge request!368Add more automatic checks for operators
...@@ -364,20 +364,19 @@ class MatrixProductOperator(EndomorphicOperator): ...@@ -364,20 +364,19 @@ class MatrixProductOperator(EndomorphicOperator):
`dot()` and `transpose()` in the style of numpy arrays. `dot()` and `transpose()` in the style of numpy arrays.
""" """
def __init__(self, domain, matrix): def __init__(self, domain, matrix):
self._domain = domain self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
raise TypeError('Only 1D-domain supported yet.')
if matrix.shape != (*shp, *shp):
raise ValueError
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._mat = matrix self._mat = matrix
self._mat_tr = matrix.transpose() self._mat_tr = matrix.transpose().conjugate()
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
res = x.to_global_data() res = x.to_global_data()
if mode == self.TIMES: f = self._mat.dot if mode == self.TIMES else self._mat_tr.dot
res = self._mat.dot(res) res = f(res)
if mode == self.ADJOINT_TIMES:
res = self._mat_tr.dot(res)
return Field.from_global_data(self._domain, res) return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
...@@ -297,6 +297,18 @@ def testValueInserter(sp, seed): ...@@ -297,6 +297,18 @@ def testValueInserter(sp, seed):
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator(sp, seed):
np.random.seed(seed)
mat = np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
mat = mat + 1j*np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
@pmp('seed', [12, 3]) @pmp('seed', [12, 3])
def testPartialExtractor(seed): def testPartialExtractor(seed):
np.random.seed(seed) np.random.seed(seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment