# regridding_operator.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
20
from .. import dobj
21
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..domains.rg_space import RGSpace
23
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
24
from ..utilities import infer_space, special_add_at
25 26 27 28
from .linear_operator import LinearOperator


class RegriddingOperator(LinearOperator):
    """Linearly interpolates an RGSpace to an RGSpace with coarser resolution.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        domain[space] needs to be an :class:`RGSpace`.
    new_shape : tuple of int
        Shape of the space which domain[space] is replaced by. Each entry must
        be smaller or equal to the respective entry in `domain[space].shape`.
    space : int
        Index of space in `domain` on which the operator shall act.
        Default is 0.

    Raises
    ------
    TypeError
        If domain[space] is not an :class:`RGSpace`.
    ValueError
        If `new_shape` has the wrong number of entries, or any entry is
        larger than the old one, zero or negative.

    Notes
    -----
    The target space keeps the total extent of every axis, i.e. its pixel
    distances are ``dom.distances[i]*dom.shape[i]/new_shape[i]``. Along each
    axis, output pixel ``j`` is ``(1-f)*x[b] + f*x[b+1]``, where ``b`` and
    ``f`` are the precomputed integer index and fractional offset of the
    output pixel position measured in input pixels.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = infer_space(self._domain, space)
        dom = self._domain[self._space]

        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a > b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be larger than old shape")
        if any(ii <= 0 for ii in new_shape):
            raise ValueError('New shape must not be zero or negative.')

        # Preserve the physical extent (shape*distance) of every axis.
        newdist = tuple(dom.distances[i]*dom.shape[i]/new_shape[i]
                        for i in range(len(dom.shape)))

        tgt = RGSpace(new_shape, newdist)
        self._target = list(self._domain)
        self._target[self._space] = tgt
        self._target = DomainTuple.make(self._target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Per axis of the regridded space, precompute the lower neighbouring
        # input index (_bindex) and the fractional offset from it (_frac).
        ndim = len(new_shape)
        self._bindex = [None] * ndim
        self._frac = [None] * ndim
        for d in range(ndim):
            # Output pixel positions expressed in units of input pixels.
            tmp = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
            # Builtin `int` is exactly what the removed `np.int` alias
            # referred to (NumPy >= 1.24 no longer provides `np.int`).
            # Clamping to shape-2 keeps bindex+1 a valid index. For an axis
            # of length 1 this gives bindex=-1 and frac=1, so all weight
            # lands on index 0 and index -1 (the same element) gets weight 0.
            self._bindex[d] = np.minimum(dom.shape[d]-2, tmp.astype(int))
            self._frac[d] = tmp-self._bindex[d]

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        ndim = len(self.target.shape)
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        # Position of the regridded space's first axis in the full tuple;
        # used to map a global axis d to the per-axis tables above.
        d0 = self._target.axes[self._space][0]
        for d in self._target.axes[self._space]:
            idx = (slice(None),) * d
            # Reshape the 1-D weights of axis d so they broadcast against
            # the full array along that axis only.
            wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))

            # NOTE(review): presumably makes axis d fully local so it can be
            # indexed directly; `x` is rebound to the local ndarray.
            v, x = dobj.ensure_not_distributed(v, (d,))

            if mode == self.ADJOINT_TIMES:
                # Scatter each coarse value back onto its two fine
                # neighbours; assumes special_add_at accumulates along
                # axis d at the given indices.
                shp = list(x.shape)
                shp[d] = tgtshp[d]
                xnew = np.zeros(shp, dtype=x.dtype)
                xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
                xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
            else:  # TIMES
                # Gather the two neighbours and blend them linearly.
                xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
                xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt

            curshp[d] = xnew.shape[d]
            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))