Commit a6d80b97 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'matrix_product_operator' into 'NIFTy_5'

add operator to multiply fields with matrices

See merge request !323
parents 157402e6 676302fe
Pipeline #49109 passed with stages
in 19 minutes and 7 seconds
......@@ -45,7 +45,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator)
FieldAdapter, ducktape, GeometryRemover, NullOperator,
from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......@@ -344,3 +344,35 @@ class _PartialExtractor(LinearOperator):
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
class MatrixProductOperator(EndomorphicOperator):
"""Endomorphic matrix multiplication with input field.
domain: :class:`Domain` or :class:`DomainTuple`
Domain of the operator.
If :class:`DomainTuple` it is assumed to have only one entry.
matrix: scipy.sparse matrix or numpy array
Matrix of shape `(domain.shape, domain.shape)`. Needs to support
`dot()` and `transpose()` in the style of numpy arrays.
def __init__(self, domain, matrix):
self._domain = domain
self._capability = self.TIMES | self.ADJOINT_TIMES
self._mat = matrix
self._mat_tr = matrix.transpose()
def apply(self, x, mode):
self._check_input(x, mode)
res = x.to_global_data()
if mode == self.TIMES:
res =
if mode == self.ADJOINT_TIMES:
res =
return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment