From 29f141e55930cd19d00cbfb2f25512d0b6ab09a3 Mon Sep 17 00:00:00 2001 From: Lukas Platz <lplatz@mpa-garching.mpg.de> Date: Mon, 6 May 2019 14:56:53 +0200 Subject: [PATCH] add operator to multiply fields with matrices --- nifty5/__init__.py | 3 +- nifty5/operators/simple_linear_operators.py | 34 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index eee17e5bb..8b96ca3c2 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -45,7 +45,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.outer_product_operator import OuterProduct from .operators.simple_linear_operators import ( VdotOperator, ConjugationOperator, Realizer, - FieldAdapter, ducktape, GeometryRemover, NullOperator) + FieldAdapter, ducktape, GeometryRemover, NullOperator, + MatrixProductOperator) from .operators.value_inserter import ValueInserter from .operators.energy_operators import ( EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 1f3d11d2d..15e4d4fe2 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain from ..multi_field import MultiField from .endomorphic_operator import EndomorphicOperator from .linear_operator import LinearOperator +from .. import utilities class VdotOperator(LinearOperator): @@ -344,3 +345,36 @@ class _PartialExtractor(LinearOperator): if mode == self.TIMES: return x.extract(self._target) return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) + + +class MatrixProductOperator(EndomorphicOperator): + """Endomorphic matrix multiplication with input field. + + Parameters + ---------- + domain: DomainTuple + Domain of the operator. + matrix: + Matrix of shape (field.shape, field.shape) + space: int, optional + The index of the subdomain on which the operator should act + """ + def __init__(self, domain, matrix, space=None): + self._domain = DomainTuple.make(domain) + self._capability = self.TIMES | self.ADJOINT_TIMES + self._space = utilities.infer_space(self._domain, space) + + self._mat = matrix + self._mat_tr = matrix.transpose() + + def apply(self, x, mode): + self._check_input(x, mode) + res = x.to_global_data() + if mode == self.TIMES: + res = self._mat.dot(res) + if mode == self.ADJOINT_TIMES: + res = self._mat_tr.dot(res) + return Field.from_global_data(self._domain, res) + + def __repr__(self): + return "MatrixProductOperator" -- GitLab