Commit c88eadc1 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix MatrixProductOperator

parent 93b3fdd9
......@@ -364,20 +364,19 @@ class MatrixProductOperator(EndomorphicOperator):
`dot()` and `transpose()` in the style of numpy arrays.
"""
def __init__(self, domain, matrix):
self._domain = domain
self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
raise TypeError('Only 1D-domain supported yet.')
if matrix.shape != (*shp, *shp):
raise ValueError
self._capability = self.TIMES | self.ADJOINT_TIMES
self._mat = matrix
self._mat_tr = matrix.transpose()
self._mat_tr = matrix.transpose().conjugate()
def apply(self, x, mode):
self._check_input(x, mode)
res = x.to_global_data()
if mode == self.TIMES:
res = self._mat.dot(res)
if mode == self.ADJOINT_TIMES:
res = self._mat_tr.dot(res)
f = self._mat.dot if mode == self.TIMES else self._mat_tr.dot
res = f(res)
return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
......@@ -297,6 +297,18 @@ def testValueInserter(sp, seed):
ift.extra.consistency_check(op)
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator(sp, seed):
np.random.seed(seed)
mat = np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
mat = mat + 1j*np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
np.random.seed(seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment