linear_interpolation.py 4.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
19
20
from functools import reduce
from operator import mul

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
21
import numpy as np
Philipp Arras's avatar
Cleanup    
Philipp Arras committed
22
23
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
24

25
from ..domains.rg_space import RGSpace
Philipp Arras's avatar
Cleanup    
Philipp Arras committed
26
27
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
28
from ..sugar import makeDomain
29
30
31
32
from .linear_operator import LinearOperator


class LinearInterpolator(LinearOperator):
    def __init__(self, domain, sampling_points):
        """
        Multilinear interpolation for points in an RGSpace

        Parameters
        ----------
        domain : RGSpace or tuple of RGSpace
            Regular grid(s) on which the input field lives. If several
            RGSpaces are given, their axes are concatenated in order, exactly
            as in `domain.shape`.
        sampling_points : numpy.ndarray
            Positions at which to interpolate, shape (dim, n_points), where
            `dim` is the total number of axes of `domain`. Positions are in
            the same units as the `distances` of the RGSpaces (they are
            divided by the grid spacing to obtain fractional grid indices).

        Raises
        ------
        TypeError
            If a domain is not an RGSpace, if `sampling_points` is not a 2D
            numpy.ndarray, or if its number of rows does not match the total
            number of axes of `domain`.

        Notes
        -----
        Positions that are not within the RGSpace are wrapped according to
        periodic boundary conditions. This reflects the general property of
        RGSpaces to be tori topologically.
        """
        self._domain = makeDomain(domain)
        for dom in self.domain:
            if not isinstance(dom, RGSpace):
                raise TypeError("LinearInterpolator only supports RGSpace "
                                "domains, got {}".format(type(dom).__name__))

        # Validate the type before touching `.shape`, so that wrong input
        # yields a TypeError instead of an AttributeError.
        if not (isinstance(sampling_points, np.ndarray)
                and sampling_points.ndim == 2):
            raise TypeError("sampling_points must be a 2D numpy.ndarray of "
                            "shape (dim, n_points)")
        n_dim, n_points = sampling_points.shape
        # One row per axis of the full domain (axes of all spaces
        # concatenated); this is what ravel_multi_index needs in _build_mat.
        n_axes = len(self.domain.shape)
        if n_dim != n_axes:
            raise TypeError("sampling_points has {} rows, but the domain has "
                            "{} axes".format(n_dim, n_axes))

        self._target = makeDomain(UnstructuredDomain(n_points))
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._build_mat(sampling_points, n_points)

    def _build_mat(self, sampling_points, N_points):
        """Assemble the sparse (N_points x domain size) interpolation matrix.

        Every sampling point receives one matrix entry for each of the
        2**ndim corners of the grid cell that contains it. The weight is the
        product over axes of the linear interpolation weights.
        """
        ndim = sampling_points.shape[0]
        shape = self.domain.shape

        # All 2**ndim corner offsets of a unit cell; one corner per column.
        corners = np.mgrid[(slice(0, 2),)*ndim]
        corners = np.array(list(map(np.ravel, corners)))

        # Grid spacing per axis, flattened in the same order as `shape`.
        # Flattening (instead of np.array of per-space lists) also works when
        # the spaces have different numbers of axes.
        dist = np.array([d for dom in self.domain for d in dom.distances])
        dist = dist.reshape(-1, 1)

        pos = sampling_points/dist
        # floor (not truncation towards zero) so that negative positions get
        # the lower neighbour and a fractional part in [0, 1); the modulo
        # below then wraps indices periodically as documented.
        lower = np.floor(pos)
        excess = pos - lower
        lower = lower.astype(np.int64)
        max_index = np.array(shape).reshape(-1, 1)

        n_corners = corners.shape[1]
        data = np.empty((n_corners, N_points))
        ii = np.empty((n_corners, N_points), dtype=np.int64)
        jj = np.empty((n_corners, N_points), dtype=np.int64)
        rows = np.arange(N_points)
        for i, corner in enumerate(corners.T):
            corner = corner.reshape(-1, 1)
            # weight is (1 - excess) on the lower and excess on the upper
            # neighbour, multiplied over all axes
            data[i] = np.prod(np.abs(1 - corner - excess), axis=0)
            idx = (lower + corner) % max_index
            ii[i] = rows
            jj[i] = np.ravel_multi_index(idx, shape)

        mat = coo_matrix((data.reshape(-1), (ii.reshape(-1), jj.reshape(-1))),
                         (N_points, int(np.prod(shape))))
        self._mat = aslinearoperator(mat)

    def apply(self, x, mode):
        """Apply the interpolation (TIMES) or its adjoint (ADJOINT_TIMES).

        In TIMES mode the field on the RGSpace domain is flattened and mapped
        to values at the sampling points. In ADJOINT_TIMES mode values at
        the sampling points are scattered back onto the grid with the same
        weights and reshaped to the domain's shape.
        """
        self._check_input(x, mode)
        x_val = x.to_global_data()
        if mode == self.TIMES:
            res = self._mat.matvec(x_val.reshape(-1))
        else:
            res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
        return Field.from_global_data(self._tgt(mode), res)