Skip to content
Snippets Groups Projects
Commit 29f141e5 authored by Lukas Platz's avatar Lukas Platz
Browse files

add operator to multiply fields with matrices

parent 157402e6
No related branches found
No related tags found
1 merge request!323add operator to multiply fields with matrices
Pipeline #47790 passed
......@@ -45,7 +45,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator)
FieldAdapter, ducktape, GeometryRemover, NullOperator,
MatrixProductOperator)
from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
......@@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .. import utilities
class VdotOperator(LinearOperator):
......@@ -344,3 +345,36 @@ class _PartialExtractor(LinearOperator):
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
class MatrixProductOperator(EndomorphicOperator):
"""Endomorphic matrix multiplication with input field.
Parameters
----------
domain: DomainTuple
Domain of the operator.
matrix:
Matrix of shape (field.shape, field.shape)
space: int, optional
The index of the subdomain on which the operator should act
"""
def __init__(self, domain, matrix, space=None):
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
self._mat = matrix
self._mat_tr = matrix.transpose()
def apply(self, x, mode):
self._check_input(x, mode)
res = x.to_global_data()
if mode == self.TIMES:
res = self._mat.dot(res)
if mode == self.ADJOINT_TIMES:
res = self._mat_tr.dot(res)
return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment