Scheduled maintenance on Monday 2019-06-24 between 10:00-11:00 CEST

Commit 29f141e5 authored by Platz, Lukas (lplatz)'s avatar Platz, Lukas (lplatz)

add operator to multiply fields with matrices

parent 157402e6
Pipeline #47790 passed with stages
in 8 minutes and 19 seconds
......@@ -45,7 +45,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator)
FieldAdapter, ducktape, GeometryRemover, NullOperator,
MatrixProductOperator)
from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
......@@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .. import utilities
class VdotOperator(LinearOperator):
......@@ -344,3 +345,36 @@ class _PartialExtractor(LinearOperator):
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
class MatrixProductOperator(EndomorphicOperator):
"""Endomorphic matrix multiplication with input field.
Parameters
----------
domain: DomainTuple
Domain of the operator.
matrix:
Matrix of shape (field.shape, field.shape)
space: int, optional
The index of the subdomain on which the operator should act
"""
def __init__(self, domain, matrix, space=None):
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
self._mat = matrix
self._mat_tr = matrix.transpose()
def apply(self, x, mode):
self._check_input(x, mode)
res = x.to_global_data()
if mode == self.TIMES:
res = self._mat.dot(res)
if mode == self.ADJOINT_TIMES:
res = self._mat_tr.dot(res)
return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment