# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
from functools import reduce
19
from operator import add
Philipp Arras's avatar
Philipp Arras committed
20

Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
21
import numpy as np
Philipp Arras's avatar
Cleanup  
Philipp Arras committed
22 23
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
24

25
from ..domains.rg_space import RGSpace
Philipp Arras's avatar
Cleanup  
Philipp Arras committed
26 27
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
28
from ..sugar import makeDomain
29 30 31 32
from .linear_operator import LinearOperator


class LinearInterpolator(LinearOperator):
    """Multilinear interpolation for points in an RGSpace

    The operator maps a field living on a (product of) regular grid(s) to
    its values at a set of arbitrary sampling points, obtained by
    multilinear interpolation between the surrounding grid points.

    Parameters
    ----------
    domain : RGSpace or tuple of RGSpace
        Domain of the input field. It is passed through `makeDomain`; every
        space in the result must be an RGSpace. Spaces with different
        numbers of dimensions may be combined.
    sampling_points : numpy.ndarray
        Positions at which to interpolate, shape (dim, ndata), where `dim`
        is the total number of axes of `domain` (summed over all its
        spaces) and `ndata` the number of points. Coordinates are measured
        in the same units as the spaces' `distances`, with grid index 0 at
        coordinate 0.

    Raises
    ------
    TypeError
        If `domain` contains a space that is not an RGSpace, or if
        `sampling_points` is not a 2D numpy array whose first axis matches
        the total number of axes of `domain`.

    Notes
    -----
    Positions that are not within the RGSpace are wrapped according to
    periodic boundary conditions. This reflects the general property of
    RGSpaces to be tori topologically.

    The target is an UnstructuredDomain with one entry per sampling point;
    the operator supports TIMES and ADJOINT_TIMES.
    """

    def __init__(self, domain, sampling_points):
        self._domain = makeDomain(domain)
        for dom in self.domain:
            if not isinstance(dom, RGSpace):
                raise TypeError("LinearInterpolator: all domains must be "
                                "RGSpaces, got {}".format(type(dom).__name__))
        # Total number of grid axes, summed over all spaces of the domain.
        # Spaces of differing dimensionality are fine because `_build_mat`
        # treats every axis independently.
        n_axes = sum(len(dom.shape) for dom in self.domain)

        # Check the type first so that non-array input yields a TypeError
        # instead of an AttributeError on `.shape`.
        if not (isinstance(sampling_points, np.ndarray)
                and sampling_points.ndim == 2):
            raise TypeError("sampling_points must be a 2D numpy.ndarray of "
                            "shape (dim, ndata)")
        n_dim, n_points = sampling_points.shape
        if n_dim != n_axes:
            raise TypeError("sampling_points has {} coordinate rows, but the "
                            "domain has {} axes".format(n_dim, n_axes))
        self._target = makeDomain(UnstructuredDomain(n_points))
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._build_mat(sampling_points, n_points)

    def _build_mat(self, sampling_points, N_points):
        """Assemble the sparse (N_points, domain size) interpolation matrix.

        Row `p` holds, for sampling point `p`, the multilinear weights of
        the 2**ndim grid points surrounding it; column indices are flat
        (C-order) indices into `self.domain.shape`.
        """
        ndim = sampling_points.shape[0]
        shape = self.domain.shape
        # Corner offsets of the unit hypercube, shape (ndim, 2**ndim):
        # column c holds the 0/1 offset of corner c along every axis.
        mg = np.mgrid[(slice(0, 2),)*ndim]
        mg = np.array(list(map(np.ravel, mg)))
        n_corners = mg.shape[1]
        # Grid spacing of every axis, concatenated over all spaces in the
        # same order as `shape`, as a column vector for broadcasting.
        # Flattening per-space distances (instead of building a nested
        # array) also works when the spaces have different dimensionality.
        dist = np.array([d for dom in self.domain for d in dom.distances])
        dist = dist.reshape(-1, 1)
        # Split each coordinate (in units of grid cells) into the index of
        # the lower neighbour and the fractional offset in [0, 1).
        pos = sampling_points/dist
        lower = np.floor(pos)
        excess = pos - lower
        lower = lower.astype(np.int64)
        max_index = np.array(shape).reshape(-1, 1)

        data = np.empty((n_corners, N_points))
        jj = np.empty((n_corners, N_points), dtype=np.int64)
        for c in range(n_corners):
            offset = mg[:, c].reshape(-1, 1)
            # Per axis the weight is (1 - excess) for offset 0 and excess
            # for offset 1; the corner weight is the product over all axes.
            data[c] = np.prod(np.abs(1 - offset - excess), axis=0)
            # Periodic wrap-around of the neighbour's grid indices.
            neighbour = (lower + offset) % max_index
            jj[c] = np.ravel_multi_index(neighbour, shape)
        # The (corner, point) arrays are flattened in C order, so the row
        # index (= sampling point) sequence 0..N_points-1 repeats per corner.
        ii = np.tile(np.arange(N_points), n_corners)
        mat = coo_matrix((data.reshape(-1), (ii, jj.reshape(-1))),
                         (N_points, np.prod(shape)))
        self._mat = aslinearoperator(mat)

    def apply(self, x, mode):
        self._check_input(x, mode)
        x_val = x.to_global_data()
        if mode == self.TIMES:
            # Grid field -> interpolated values at the sampling points.
            res = self._mat.matvec(x_val.reshape(-1))
        else:
            # Adjoint: distribute point values back onto the grid with the
            # same weights, then restore the domain's shape.
            res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
        return Field.from_global_data(self._tgt(mode), res)