central_zero_padder.py 4.19 KB
Newer Older
Julia Stadler's avatar
Julia Stadler committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Julia Stadler's avatar
Julia Stadler committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Julia Stadler's avatar
Julia Stadler committed
18
19
import itertools

Philipp Arras's avatar
Philipp Arras committed
20
21
22
import numpy as np

from .. import dobj, utilities
Martin Reinecke's avatar
Martin Reinecke committed
23
from ..compat import *
Julia Stadler's avatar
Julia Stadler committed
24
25
26
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
27
from .linear_operator import LinearOperator
Julia Stadler's avatar
Julia Stadler committed
28
29


Martin Reinecke's avatar
Martin Reinecke committed
30
31
# MR FIXME: for even axis lengths, we probably should split the value at the
#           highest frequency.
Philipp Arras's avatar
Philipp Arras committed
32
class CentralZeroPadder(LinearOperator):
    """Operator that enlarges a field's domain by adding zeros in the middle.

    Along every axis of the selected RGSpace (original length ``n``) the
    first ``(n+1)//2`` entries are placed at the beginning of the enlarged
    axis and the last ``n//2`` entries at its end; the entries in between
    are zero.  The adjoint cuts an enlarged field back to the original shape
    by keeping exactly those leading and trailing entries and discarding the
    central ones.

    Parameters
    ----------
    domain : Domain, tuple of Domains or DomainTuple
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    new_shape : tuple
        Shape of the target domain.  Must have as many entries as
        ``domain[space].shape``, and no entry may be smaller than the
        corresponding original one.
    space : int, optional
        The index of the subdomain on which the operator should act.
        If None, it is set to 0 if `domain` contains exactly one space.
        `domain[space]` must be an RGSpace.

    Raises
    ------
    TypeError
        If `domain[space]` is not an RGSpace.
    ValueError
        If `new_shape` has a different number of entries than
        `domain[space].shape`, or is smaller along any axis.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]

        # verify domains
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # make target space: same distances and harmonicity as the input
        # space, only the number of pixels per axis grows
        tgt = RGSpace(new_shape, dom.distances, harmonic=dom.harmonic)
        target = list(self._domain)
        target[self._space] = tgt
        self._target = DomainTuple.make(target)

        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Build the index tuples used to copy data between the small and
        # the large array.  Each axis of the padded space offers two choices:
        # the leading (n+1)//2 entries, and the trailing n//2 entries.  The
        # trailing slice is written with negative indices (walking backwards
        # from the end), so the very same slice addresses the tail of both
        # the small and the large array in matching order.  All other axes
        # are taken whole.  The Cartesian product of the per-axis choices
        # yields 2**k index tuples, k being the number of padded axes.
        axes = self._target.axes[self._space]
        choices = []
        for i, n in enumerate(self._domain.shape):
            if i in axes:
                choices.append((slice(0, (n+1)//2),
                                slice(-1, -1-(n//2), -1)))
            else:
                choices.append((slice(None),))
        self.slicer = tuple(itertools.product(*choices))

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        shp_out = self._tgt(mode).shape
        # The copy below indexes along the padded axes, so the data must not
        # be distributed along them.  The second return value is indexed like
        # a numpy array below (presumably the process-local data - confirm
        # against dobj).
        v, loc = dobj.ensure_not_distributed(v, self._target.axes[self._space])
        curax = dobj.distaxis(v)

        # TIMES: start from zeros of the (larger) target shape so that the
        # central gap stays zero.
        # ADJOINT_TIMES: the blocks in self.slicer tile the (smaller) output
        # completely, so no initialisation is needed; the central part of the
        # input is simply never read.
        alloc = np.zeros if mode == self.TIMES else np.empty
        y = alloc(dobj.local_shape(shp_out, curax), dtype=loc.dtype)
        for idx in self.slicer:
            y[idx] = loc[idx]

        v = dobj.from_local_data(shp_out, y, distaxis=curax)
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))