From dc5042f0512576521c61a7be228a8bf6f2baa28f Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@leike.name>
Date: Mon, 8 Oct 2018 17:57:25 +0200
Subject: [PATCH] added an operator doing multilinear interpolation

---
 nifty5/__init__.py                       |  1 +
 nifty5/operators/linear_interpolation.py | 81 ++++++++++++++++++++++++
 test/test_operators/test_adjoint.py      | 13 ++--
 3 files changed, 89 insertions(+), 6 deletions(-)
 create mode 100644 nifty5/operators/linear_interpolation.py

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index b766ec525..773c52779 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -24,6 +24,7 @@ from .operators.diagonal_operator import DiagonalOperator
 from .operators.distributors import DOFDistributor, PowerDistributor
 from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
 from .operators.contraction_operator import ContractionOperator
+from .operators.linear_interpolation import LinearInterpolator
 from .operators.endomorphic_operator import EndomorphicOperator
 from .operators.exp_transform import ExpTransform
 from .operators.harmonic_operators import (
diff --git a/nifty5/operators/linear_interpolation.py b/nifty5/operators/linear_interpolation.py
new file mode 100644
index 000000000..36f31623d
--- /dev/null
+++ b/nifty5/operators/linear_interpolation.py
@@ -0,0 +1,81 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from __future__ import absolute_import, division, print_function
+
+from ..compat import *
+
+
+from .. import Field, UnstructuredDomain
+from ..sugar import makeDomain 
+from .linear_operator import LinearOperator
+from numpy import array, prod, mgrid, int64, arange, ravel_multi_index, zeros, abs, ravel
+from scipy.sparse import coo_matrix
+from scipy.sparse.linalg import aslinearoperator
+
+
+class LinearInterpolator(LinearOperator):
+    def __init__(self, domain, positions):
+        """
+
+        :param domain:
+            RGSpace
+        :param target:
+            UnstructuredDomain, shape (ndata,)
+        :param positions:
+            positions at which to interpolate
+            Field with UnstructuredDomain, shape (dim, ndata)
+        """
+        self._domain = makeDomain(domain)
+        N_points = positions.shape[1]
+        self._target = makeDomain(UnstructuredDomain(N_points))
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        self._build_mat(positions, N_points)
+
+    def _build_mat(self, positions, N_points):
+        ndim = positions.shape[0]
+        mg = mgrid[(slice(0,2),)*ndim]
+        mg = array(list(map(ravel, mg)))
+        dist = array(self.domain[0].distances).reshape((-1,1))
+        pos = positions/dist
+        excess = pos-pos.astype(int64)
+        pos = pos.astype(int64)
+        data = zeros((len(mg[0]), N_points))
+        ii = zeros((len(mg[0]), N_points), dtype=int64)
+        jj = zeros((len(mg[0]), N_points), dtype=int64)
+        for i in range(len(mg[0])):
+            factor = prod(abs(1-mg[:,i].reshape((-1,1))-excess),axis=0)
+            #print(factor)
+            data[i,:] = factor
+            fromi = pos+mg[:,i].reshape((-1,1))
+            ii[i, :] = arange(N_points)
+            jj[i, :] = ravel_multi_index(fromi, self.domain.shape)
+        self._mat = coo_matrix((data.reshape(-1), 
+            (ii.reshape(-1),jj.reshape(-1))),
+            (N_points, prod(self.domain.shape))) 
+        self._mat = aslinearoperator(self._mat)
+        
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        x_val = x.to_global_data()
+        if mode == self.TIMES:
+            res = self._mat.matvec(x_val.reshape((-1,)))
+            return Field.from_global_data(self.target, res)
+        res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
+        return Field.from_global_data(self.domain, res)
+
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index 34b8e1b0c..cafb0af13 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -58,12 +58,13 @@ class Consistency_Tests(unittest.TestCase):
         op = a+b
         ift.extra.consistency_check(op, dtype, dtype)
 
-    @expand(product(_h_spaces + _p_spaces + _pow_spaces,
-                    [np.float64, np.complex128]))
-    def testVdotOperator(self, sp, dtype):
-        op = ift.VdotOperator(ift.Field.from_random("normal", sp,
-                                                    dtype=dtype))
-        ift.extra.consistency_check(op, dtype, dtype)
+    def testLinearInterpolator(self):
+        sp = ift.RGSpace((10,8), distances=(0.1, 3.5))
+        pos = np.random.rand(2, 23)
+        pos[0,:] *= 0.9
+        pos[1,:] *= 7*3.5
+        op = ift.LinearInterpolator(sp, pos)
+        ift.extra.consistency_check(op)
 
     @expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
                      (ift.RGSpace((24, 31), distances=(0.4, 2.34),
-- 
GitLab