From cf834a3e9ec321e51be8b8fe8076e0de38d7c33a Mon Sep 17 00:00:00 2001
From: Theo Steininger <theos@mpa-garching.mpg.de>
Date: Wed, 15 Feb 2017 12:53:09 +0100
Subject: [PATCH] Added ProjectionOperator

---
 nifty/operators/__init__.py                   |  2 +
 .../operators/projection_operator/__init__.py |  3 +
 .../projection_operator.py                    | 99 +++++++++++++++++++
 3 files changed, 104 insertions(+)
 create mode 100644 nifty/operators/projection_operator/__init__.py
 create mode 100644 nifty/operators/projection_operator/projection_operator.py

diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py
index dbf8ab6b3..dc92c25af 100644
--- a/nifty/operators/__init__.py
+++ b/nifty/operators/__init__.py
@@ -35,4 +35,6 @@ from invertible_operator_mixin import InvertibleOperatorMixin
 
 from propagator_operator import PropagatorOperator
 
+from propagator_operator import PropagatorOperator
+
 from composed_operator import ComposedOperator
diff --git a/nifty/operators/projection_operator/__init__.py b/nifty/operators/projection_operator/__init__.py
new file mode 100644
index 000000000..c5b2e1747
--- /dev/null
+++ b/nifty/operators/projection_operator/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+from projection_operator import ProjectionOperator
diff --git a/nifty/operators/projection_operator/projection_operator.py b/nifty/operators/projection_operator/projection_operator.py
new file mode 100644
index 000000000..ce7c0ac35
--- /dev/null
+++ b/nifty/operators/projection_operator/projection_operator.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+
+from nifty.field import Field
+
+from nifty.operators.endomorphic_operator import EndomorphicOperator
+
+
+class ProjectionOperator(EndomorphicOperator):
+
+    # ---Overwritten properties and methods---
+
+    def __init__(self, projection_field):
+        if not isinstance(projection_field, Field):
+            raise TypeError("The projection_field must be a NIFTy-Field"
+                            "instance.")
+        self._projection_field = projection_field
+        self._unitary = None
+
+    def _times(self, x, spaces):
+        # if the domain matches directly
+        # -> multiply the fields directly
+        if x.domain == self.domain:
+            # here the actual multiplication takes place
+            dotted = (self._projection_field * x).sum()
+            return self._projection_field * dotted
+
+        # if the distribution_strategy of self is sub-slice compatible to
+        # the one of x, reshape the local data of self and apply it directly
+        active_axes = []
+        if spaces is None:
+            active_axes = range(len(x.shape))
+        else:
+            for space_index in spaces:
+                active_axes += x.domain_axes[space_index]
+
+        axes_local_distribution_strategy = \
+            x.val.get_axes_local_distribution_strategy(active_axes)
+        if axes_local_distribution_strategy == \
+           self._projection_field.distribution_strategy:
+            local_projection_vector = \
+                self._projection_field.val.get_local_data(copy=False)
+        else:
+            # create an array that is sub-slice compatible
+            self.logger.warn("The input field is not sub-slice compatible to "
+                             "the distribution strategy of the operator. "
+                             "Performing an probably expensive "
+                             "redistribution.")
+            redistr_projection_val = self._projection_field.val.copy(
+                distribution_strategy=axes_local_distribution_strategy)
+            local_projection_vector = \
+                redistr_projection_val.get_local_data(copy=False)
+
+        local_x = x.val.get_local_data(copy=False)
+
+        l = len(local_projection_vector.shape)
+        sublist_projector = range(l)
+        sublist_x = np.arange(len(local_x.shape)) + l
+
+        for i in xrange(l):
+            a = active_axes[i]
+            sublist_x[a] = i
+
+        dotted = np.einsum(local_projection_vector, sublist_projector,
+                           local_x, sublist_x)
+
+        # get those elements from sublist_x that haven't got contracted
+        sublist_dotted = sublist_x[sublist_x >= l]
+
+        remultiplied = np.einsum(local_projection_vector, sublist_projector,
+                                 dotted, sublist_dotted,
+                                 sublist_x)
+        result_field = x.copy_empty(dtype=remultiplied.dtype)
+        result_field.val.set_local_data(remultiplied, copy=False)
+        return result_field
+
+    def _inverse_times(self, x, spaces):
+        raise NotImplementedError("The ProjectionOperator is a singular "
+                                  "operator and therefore has no inverse.")
+
+    # ---Mandatory properties and methods---
+
+    @property
+    def domain(self):
+        return self._projection_field.domain
+
+    @property
+    def implemented(self):
+        return True
+
+    @property
+    def unitary(self):
+        if self._unitary is None:
+            self._unitary = (self._projection_field.val == 1).all()
+        return self._unitary
+
+    @property
+    def symmetric(self):
+        return True
-- 
GitLab