linear_interpolation.py 3.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import absolute_import, division, print_function

from ..compat import *


from .. import Field, UnstructuredDomain
from ..sugar import makeDomain 
from .linear_operator import LinearOperator
from numpy import (array, prod, mgrid, int64, arange, ravel_multi_index,
                   zeros, abs, ravel, floor)
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator


class LinearInterpolator(LinearOperator):
    """Multilinear interpolation from a regular grid onto sampling points.

    Applied in the forward direction, the operator maps a field on the grid
    ``domain`` to its multilinearly interpolated values at ``ndata`` sampling
    positions. Its target is an ``UnstructuredDomain`` with one entry per
    position. The adjoint scatters point values back onto the grid with the
    same weights.
    """

    def __init__(self, domain, positions):
        """
        :param domain:
            RGSpace. Only ``domain[0].distances`` is used as the grid spacing,
            so presumably a single RGSpace is expected (verify before using
            product domains).
        :param positions:
            positions at which to interpolate, in the same physical units as
            the grid distances. Array-like of shape (dim, ndata), where dim is
            the number of grid axes and ndata the number of sampling points.
            A Field is also accepted; its global data is used.

        The target of the operator is an UnstructuredDomain of shape (ndata,).
        """
        self._domain = makeDomain(domain)
        if isinstance(positions, Field):
            positions = positions.to_global_data()
        # Plain float ndarray: the code below relies on ndarray arithmetic,
        # broadcasting and comparisons.
        positions = array(positions, dtype=float)
        N_points = positions.shape[1]
        self._target = makeDomain(UnstructuredDomain(N_points))
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._build_mat(positions, N_points)

    def _build_mat(self, positions, N_points):
        """Assemble the sparse (N_points x grid size) interpolation matrix.

        Each sampling point receives contributions from the 2**ndim grid
        nodes at the corners of the cell containing it. Each node is weighted
        by the product over axes of (1 - |corner offset - fractional
        position|), i.e. standard multilinear weights.

        The grid is treated as periodic. The upper neighbour of the last node
        along an axis is node 0 of that axis, so points between the last node
        and ``shape * distance`` are interpolated towards node 0. A point
        lying exactly on the last node gives that neighbour weight 0.

        :raises ValueError: if the number of position coordinates differs
            from the number of grid axes, or if any coordinate is not finite
            or lies outside ``[0, shape * distance]`` along its axis.
        """
        ndim = positions.shape[0]
        shape = self.domain.shape
        if ndim != len(shape):
            raise ValueError(
                "positions have {} coordinates but the grid has {} axes"
                .format(ndim, len(shape)))

        # Column i of `corners` is one of the 2**ndim offset vectors in
        # {0, 1}**ndim, i.e. one corner of the unit hypercube.
        corners = mgrid[(slice(0, 2),)*ndim]
        corners = array(list(map(ravel, corners)))
        n_corners = corners.shape[1]

        # NOTE(review): only the first subspace's distances are used; this
        # presumably assumes a single RGSpace domain.
        dist = array(self.domain[0].distances).reshape((-1, 1))
        grid_shape = array(shape).reshape((-1, 1))

        # Positions in units of the grid spacing.
        pos = positions/dist
        # Written as a positive "inside" test so NaN (which compares False)
        # is rejected as well as out-of-range values.
        inside = (pos >= 0) & (pos <= grid_shape)
        if not inside.all():
            raise ValueError("Points outside the grid!")

        # floor, not truncation toward zero: `excess` must be the fractional
        # offset within the cell, in [0, 1).
        lower = floor(pos)
        excess = pos - lower
        lower = lower.astype(int64)

        data = zeros((n_corners, N_points))
        # Every corner row contributes to the same output row (the point).
        ii = zeros((n_corners, N_points), dtype=int64)
        ii[:, :] = arange(N_points)
        jj = zeros((n_corners, N_points), dtype=int64)
        for i, offset in enumerate(corners.T):
            offset = offset.reshape((-1, 1))
            data[i, :] = prod(abs(1 - offset - excess), axis=0)
            # Periodic wrap keeps indices in range for points in the last
            # cell or exactly on the last node.
            node = (lower + offset) % grid_shape
            jj[i, :] = ravel_multi_index(node, shape)

        # Duplicate (row, col) entries, e.g. on axes of length 1, are summed
        # by the COO format, which is the desired behaviour here.
        mat = coo_matrix((data.reshape(-1),
                          (ii.reshape(-1), jj.reshape(-1))),
                         (N_points, prod(shape)))
        self._mat = aslinearoperator(mat)

    def apply(self, x, mode):
        """Apply the interpolation (TIMES) or its adjoint (ADJOINT_TIMES)."""
        self._check_input(x, mode)
        x_val = x.to_global_data()
        if mode == self.TIMES:
            # Flatten the grid values in C order, matching the
            # ravel_multi_index default used to build the column indices.
            res = self._mat.matvec(x_val.reshape((-1,)))
            return Field.from_global_data(self.target, res)
        res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
        return Field.from_global_data(self.domain, res)