central_zero_padder.py 4.17 KB
Newer Older
Julia Stadler's avatar
Julia Stadler committed
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Julia Stadler's avatar
Julia Stadler committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Julia Stadler's avatar
Julia Stadler committed
18 19
import itertools

Philipp Arras's avatar
Philipp Arras committed
20 21 22
import numpy as np

from .. import dobj, utilities
Julia Stadler's avatar
Julia Stadler committed
23 24 25
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
26
from .linear_operator import LinearOperator
Julia Stadler's avatar
Julia Stadler committed
27 28


Martin Reinecke's avatar
Martin Reinecke committed
29 30
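# MR FIXME: for even axis lengths, the value at the highest frequency (index
#           n//2) is currently copied only into the trailing part of the
#           padded axis; it should probably be split between both ends.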
Philipp Arras's avatar
Philipp Arras committed
31
class CentralZeroPadder(LinearOperator):
    """Operator that enlarges a field's domain by inserting zeros in the
    middle of the array.

    Along every axis of ``domain[space]`` (length ``n``), the first
    ``(n+1)//2`` input entries are placed at the start of the output axis
    and the last ``n//2`` input entries at its end; all entries in between
    are zero. For even ``n`` the entry at index ``n//2`` therefore ends up
    in the trailing part. The adjoint crops the enlarged array back to the
    original shape by taking the same leading and trailing entries and
    discarding the middle.

    Parameters
    ----------
    domain : Domain, tuple of Domains or DomainTuple
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    new_shape : tuple of int
        Shape of the target space. Must have as many entries as
        ``domain[space].shape``, and no entry may be smaller than the
        corresponding entry of ``domain[space].shape``.
    space : int, optional
        The index of the subdomain on which the operator should act.
        Default: 0. If None, it is set to 0 if `domain` contains exactly
        one space. `domain[space]` must be an RGSpace.

    Raises
    ------
    TypeError
        If ``domain[space]`` is not an RGSpace.
    ValueError
        If `new_shape` does not have the same length as the shape of
        ``domain[space]``, or is smaller than it along any axis.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]

        # verify domains
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # make target space: same distances and harmonic flag as the
        # input space, only the number of pixels changes
        tgt = RGSpace(new_shape, dom.distances, harmonic=dom.harmonic)
        target = list(self._domain)
        target[self._space] = tgt
        self._target = DomainTuple.make(target)

        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Index tuples selecting the kept (non-zero) region; shared by
        # `times` (scatter into the padded array) and `adjoint_times`
        # (gather from it).
        self.slicer = self._build_slicer()

    def _build_slicer(self):
        """Return the tuple of index tuples that address the kept entries.

        Each element is a complete index tuple with one entry per axis of
        the whole DomainTuple. Padded axes get either the leading slice or
        the (reversed) trailing slice, and all other axes get
        ``slice(None)``. The Cartesian product enumerates every combination
        of leading/trailing parts across the padded axes.
        """
        axes = self._target.axes[self._space]
        per_axis = []
        for i, n in enumerate(self._domain.shape):
            if i in axes:
                # Sized by the smaller input length, so the slices are
                # valid on both the input and the padded array. The
                # trailing slice uses negative indices, so it addresses the
                # last n//2 entries of whichever array it is applied to.
                front = slice(0, (n+1)//2)
                back = slice(-1, -1-(n//2), -1)
                per_axis.append((front, back))
            else:
                per_axis.append((slice(None),))
        return tuple(itertools.product(*per_axis))

    def apply(self, x, mode):
        self._check_input(x, mode)
        shp_out = self._tgt(mode).shape
        # Presumably redistributes the data so that none of the padded axes
        # is the distributed one; the second return value is used below as
        # the local ndarray holding the input values.
        v, local_in = dobj.ensure_not_distributed(
            x.val, self._target.axes[self._space])
        curax = dobj.distaxis(v)
        local_shape = dobj.local_shape(shp_out, curax)

        if mode == self.TIMES:
            # Padding: entries not covered by the slicer must stay zero.
            y = np.zeros(local_shape, dtype=local_in.dtype)
        else:
            # Cropping (adjoint): the leading (n+1)//2 plus trailing n//2
            # entries cover every output entry along each padded axis, so
            # no initialization is needed.
            y = np.empty(local_shape, dtype=local_in.dtype)
        for idx in self.slicer:
            y[idx] = local_in[idx]

        v = dobj.from_local_data(shp_out, y, distaxis=curax)
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))