From 29f141e55930cd19d00cbfb2f25512d0b6ab09a3 Mon Sep 17 00:00:00 2001
From: Lukas Platz <lplatz@mpa-garching.mpg.de>
Date: Mon, 6 May 2019 14:56:53 +0200
Subject: [PATCH] add operator to multiply fields with matrices

---
 nifty5/__init__.py                          |  3 +-
 nifty5/operators/simple_linear_operators.py | 34 +++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index eee17e5bb..8b96ca3c2 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -45,7 +45,8 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
 from .operators.outer_product_operator import OuterProduct
 from .operators.simple_linear_operators import (
     VdotOperator, ConjugationOperator, Realizer,
-    FieldAdapter, ducktape, GeometryRemover, NullOperator)
+    FieldAdapter, ducktape, GeometryRemover, NullOperator,
+    MatrixProductOperator)
 from .operators.value_inserter import ValueInserter
 from .operators.energy_operators import (
     EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py
index 1f3d11d2d..15e4d4fe2 100644
--- a/nifty5/operators/simple_linear_operators.py
+++ b/nifty5/operators/simple_linear_operators.py
@@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain
 from ..multi_field import MultiField
 from .endomorphic_operator import EndomorphicOperator
 from .linear_operator import LinearOperator
+from .. import utilities
 
 
 class VdotOperator(LinearOperator):
@@ -344,3 +345,36 @@ class _PartialExtractor(LinearOperator):
         if mode == self.TIMES:
             return x.extract(self._target)
         return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
+
+
+class MatrixProductOperator(EndomorphicOperator):
+    """Endomorphic matrix multiplication with input field.
+
+    Parameters
+    ----------
+    domain: DomainTuple
+        Domain of the operator.
+    matrix:
+        Matrix of shape (field.shape, field.shape)
+    space: int, optional
+        The index of the subdomain on which the operator should act
+    """
+    def __init__(self, domain, matrix, space=None):
+        self._domain = DomainTuple.make(domain)
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        self._space = utilities.infer_space(self._domain, space)
+
+        self._mat = matrix
+        self._mat_tr = matrix.transpose()
+
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        res = x.to_global_data()
+        if mode == self.TIMES:
+            res = self._mat.dot(res)
+        if mode == self.ADJOINT_TIMES:
+            res = self._mat_tr.dot(res)
+        return Field.from_global_data(self._domain, res)
+
+    def __repr__(self):
+        return "MatrixProductOperator"
-- 
GitLab