From dc5042f0512576521c61a7be228a8bf6f2baa28f Mon Sep 17 00:00:00 2001 From: Reimar Leike <reimar@leike.name> Date: Mon, 8 Oct 2018 17:57:25 +0200 Subject: [PATCH] added an operator doing multilinear interpolation --- nifty5/__init__.py | 1 + nifty5/operators/linear_interpolation.py | 81 ++++++++++++++++++++++++ test/test_operators/test_adjoint.py | 13 ++-- 3 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 nifty5/operators/linear_interpolation.py diff --git a/nifty5/__init__.py b/nifty5/__init__.py index b766ec525..773c52779 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -24,6 +24,7 @@ from .operators.diagonal_operator import DiagonalOperator from .operators.distributors import DOFDistributor, PowerDistributor from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.contraction_operator import ContractionOperator +from .operators.linear_interpolation import LinearInterpolator from .operators.endomorphic_operator import EndomorphicOperator from .operators.exp_transform import ExpTransform from .operators.harmonic_operators import ( diff --git a/nifty5/operators/linear_interpolation.py b/nifty5/operators/linear_interpolation.py new file mode 100644 index 000000000..36f31623d --- /dev/null +++ b/nifty5/operators/linear_interpolation.py @@ -0,0 +1,81 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from __future__ import absolute_import, division, print_function + +from ..compat import * + + +from .. import Field, UnstructuredDomain +from ..sugar import makeDomain +from .linear_operator import LinearOperator +from numpy import array, prod, mgrid, int64, arange, ravel_multi_index, zeros, abs, ravel +from scipy.sparse import coo_matrix +from scipy.sparse.linalg import aslinearoperator + + +class LinearInterpolator(LinearOperator): + def __init__(self, domain, positions): + """ + + :param domain: + RGSpace + :param target: + UnstructuredDomain, shape (ndata,) + :param positions: + positions at which to interpolate + Field with UnstructuredDomain, shape (dim, ndata) + """ + self._domain = makeDomain(domain) + N_points = positions.shape[1] + self._target = makeDomain(UnstructuredDomain(N_points)) + self._capability = self.TIMES | self.ADJOINT_TIMES + self._build_mat(positions, N_points) + + def _build_mat(self, positions, N_points): + ndim = positions.shape[0] + mg = mgrid[(slice(0,2),)*ndim] + mg = array(list(map(ravel, mg))) + dist = array(self.domain[0].distances).reshape((-1,1)) + pos = positions/dist + excess = pos-pos.astype(int64) + pos = pos.astype(int64) + data = zeros((len(mg[0]), N_points)) + ii = zeros((len(mg[0]), N_points), dtype=int64) + jj = zeros((len(mg[0]), N_points), dtype=int64) + for i in range(len(mg[0])): + factor = prod(abs(1-mg[:,i].reshape((-1,1))-excess),axis=0) + #print(factor) + data[i,:] = factor + fromi = pos+mg[:,i].reshape((-1,1)) + ii[i, :] = arange(N_points) + jj[i, :] = ravel_multi_index(fromi, self.domain.shape) + self._mat = coo_matrix((data.reshape(-1), + (ii.reshape(-1),jj.reshape(-1))), + (N_points, prod(self.domain.shape))) + self._mat = aslinearoperator(self._mat) + + def apply(self, x, mode): + self._check_input(x, mode) + x_val = x.to_global_data() + if mode == self.TIMES: + res = self._mat.matvec(x_val.reshape((-1,))) + return Field.from_global_data(self.target, res) + res = self._mat.rmatvec(x_val).reshape(self.domain.shape) + return Field.from_global_data(self.domain, res) + diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 34b8e1b0c..cafb0af13 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -58,12 +58,13 @@ class Consistency_Tests(unittest.TestCase): op = a+b ift.extra.consistency_check(op, dtype, dtype) - @expand(product(_h_spaces + _p_spaces + _pow_spaces, - [np.float64, np.complex128])) - def testVdotOperator(self, sp, dtype): - op = ift.VdotOperator(ift.Field.from_random("normal", sp, - dtype=dtype)) - ift.extra.consistency_check(op, dtype, dtype) + def testLinearInterpolator(self): + sp = ift.RGSpace((10,8), distances=(0.1, 3.5)) + pos = np.random.rand(2, 23) + pos[0,:] *= 0.9 + pos[1,:] *= 7*3.5 + op = ift.LinearInterpolator(sp, pos) + ift.extra.consistency_check(op) @expand(product([(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace((24, 31), distances=(0.4, 2.34), -- GitLab