regridding_operator.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
20
from .. import dobj
21
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..domains.rg_space import RGSpace
23
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
24
from ..utilities import infer_space, special_add_at
25
26
27
28
from .linear_operator import LinearOperator


class RegriddingOperator(LinearOperator):
    """Linearly interpolates a RGSpace to an RGSpace with coarser resolution.

    The new grid covers the same physical extent as the old one: along every
    axis the pixel distance is rescaled by ``old_shape / new_shape``. The
    ``i``-th point of the new grid sits at ``i * new_distance``. Its value is
    obtained by linear interpolation between the two neighbouring points of
    the old grid. The adjoint scatters values back using the same weights.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        domain[space] needs to be an :class:`RGSpace`.
    new_shape : tuple of int
        Shape of the space which domain[space] is replaced by. Each entry must
        be smaller or equal to the respective entry in `domain[space].shape`.
    space : int
        Index of space in `domain` on which the operator shall act.
        Default is 0.

    Raises
    ------
    TypeError
        If domain[space] is not an :class:`RGSpace`.
    ValueError
        If `new_shape` has the wrong number of entries, or any of its entries
        is larger than the old one, zero, or negative.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = infer_space(self._domain, space)
        dom = self._domain[self._space]

        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a > b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be larger than old shape")
        if any(ii <= 0 for ii in new_shape):
            raise ValueError('New shape must not be zero or negative.')

        # Keep the physical extent (distance * shape) of every axis fixed.
        newdist = tuple(dom.distances[i]*dom.shape[i]/new_shape[i]
                        for i in range(len(dom.shape)))

        tgt = RGSpace(new_shape, newdist)
        target = list(self._domain)
        target[self._space] = tgt
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        ndim = len(new_shape)
        self._bindex = [None] * ndim
        self._frac = [None] * ndim
        for d in range(ndim):
            # Positions of the new grid points, measured in old pixels.
            # They are non-negative, so truncation equals floor.
            pos = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
            # Left neighbour index, clamped so that index+1 stays in range.
            # For a length-1 axis the clamp yields -1. The left weight
            # (1 - frac) is then 0, and index 0 carries the full weight.
            # Note: `np.int` was removed in NumPy 1.24; it was an alias for
            # the builtin `int`, which is used here instead.
            self._bindex[d] = np.minimum(dom.shape[d]-2, pos.astype(int))
            # Fractional offset in [0, 1] used as the right-neighbour weight.
            self._frac[d] = pos-self._bindex[d]

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        ndim = len(self.target.shape)
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        # Global index of the first axis belonging to the regridded space.
        d0 = self._target.axes[self._space][0]
        for d in self._target.axes[self._space]:
            bidx = self._bindex[d-d0]
            idx = (slice(None),) * d
            # Reshape the weights so they broadcast along axis d only.
            wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))

            # The gather/scatter below needs the whole of axis d locally.
            # NOTE(review): the second return value is presumably the local
            # ndarray of `v`; confirm against dobj.
            v, loc = dobj.ensure_not_distributed(v, (d,))

            if mode == self.ADJOINT_TIMES:
                # Scatter each value onto its two neighbours, weighted by
                # (1 - frac) and frac respectively.
                shp = list(loc.shape)
                shp[d] = tgtshp[d]
                xnew = np.zeros(shp, dtype=loc.dtype)
                xnew = special_add_at(xnew, d, bidx, loc*(1.-wgt))
                xnew = special_add_at(xnew, d, bidx+1, loc*wgt)
            else:  # TIMES
                # Gather the two neighbours and blend them linearly.
                xnew = loc[idx + (bidx,)] * (1.-wgt)
                xnew += loc[idx + (bidx+1,)] * wgt

            curshp[d] = xnew.shape[d]
            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))