# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import dobj, utilities
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator


class FieldZeroPadder(LinearOperator):
    """Operator which applies zero-padding to one of the subdomains of its
    input field.

    In `TIMES` mode the selected subdomain is enlarged to `new_shape` and the
    additional entries are filled with zeros. In `ADJOINT_TIMES` mode the
    enlarged subdomain is reduced to its original shape again.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    new_shape : list or tuple of int
        The new dimensions of the subdomain which is zero-padded.
        No entry must be smaller than the corresponding dimension in the
        operator's domain.
    space : int, optional
        The index of the subdomain to be zero-padded. If None, it is set to 0
        if domain contains exactly one space. domain[space] must be an RGSpace.
        Default: 0.
    central : bool, optional
        If `False`, padding is performed at the end of the domain axes,
        otherwise in the middle. Default: `False`.

    Raises
    ------
    TypeError
        If domain[space] is not an RGSpace.
    ValueError
        If `new_shape` has a different number of dimensions than
        domain[space], or if any of its entries is smaller than the
        corresponding dimension of domain[space].

    Notes
    -----
    When doing central padding on an axis with an even length, the "central"
    entry should in principle be split up; this is currently not done.
    """

    def __init__(self, domain, new_shape, space=0, central=False):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        self._central = central

        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be smaller than old shape")

        # The target equals the domain except for the padded subdomain, which
        # keeps its distances and harmonic flag but gets the new shape.
        tgt = list(self._domain)
        tgt[self._space] = RGSpace(new_shape, dom.distances, dom.harmonic)
        self._target = DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Pad (`TIMES`) or un-pad (`ADJOINT_TIMES`) the field `x`.

        Parameters
        ----------
        x : Field
            The input field; it is validated against `mode` by
            `_check_input`.
        mode : int
            Either `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            A field living on `self._tgt(mode)`.
        """
        self._check_input(x, mode)
        v = x.val
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        # The axes of the padded subdomain are processed one at a time.
        for d in self._target.axes[self._space]:
            if v.shape[d] == tgtshp[d]:  # nothing to do
                continue

            # NOTE(review): presumably this makes axis `d` fully available in
            # the local array `loc` - confirm against the dobj backend.
            v, loc = dobj.ensure_not_distributed(v, (d,))

            if mode == self.TIMES:
                xnew = self._pad_axis(loc, d, tgtshp[d])
            else:  # ADJOINT_TIMES
                xnew = self._crop_axis(loc, d, tgtshp[d])

            curshp[d] = xnew.shape[d]
            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))

    def _pad_axis(self, x, d, newlen):
        """Return a zero-padded copy of `x` with length `newlen` on axis `d`."""
        idx = (slice(None),) * d  # selects everything on the axes before d
        shp = list(x.shape)
        shp[d] = newlen
        xnew = np.zeros(shp, dtype=x.dtype)
        if self._central:
            Nyquist = x.shape[d]//2
            # Leading entries 0..Nyquist keep their position.
            i1 = idx + (slice(0, Nyquist+1),)
            xnew[i1] = x[i1]
            # Entries -1..-Nyquist are addressed from the end in both arrays,
            # so they land at the end of the enlarged axis; the zeros stay in
            # the middle.
            i1 = idx + (slice(None, -(Nyquist+1), -1),)
            xnew[i1] = x[i1]
            # NOTE: for an even number of pixels the entry at index Nyquist
            # is written by both assignments above, i.e. it is duplicated
            # instead of being split up (see the Notes of the class
            # docstring).
        else:
            xnew[idx + (slice(0, x.shape[d]),)] = x
        return xnew

    def _crop_axis(self, x, d, newlen):
        """Return the adjoint of `_pad_axis`: `x` reduced to length `newlen`
        on axis `d`."""
        idx = (slice(None),) * d  # selects everything on the axes before d
        if not self._central:
            # The padding was appended at the end, so it is simply cut off.
            return x[idx + (slice(0, newlen),)]

        shp = list(x.shape)
        shp[d] = newlen
        xnew = np.zeros(shp, dtype=x.dtype)
        Nyquist = newlen//2
        i1 = idx + (slice(0, Nyquist+1),)
        xnew[i1] = x[i1]
        # Accumulate instead of assigning: for an even `newlen` the entry at
        # index Nyquist is addressed by both slices and must collect both
        # contributions, mirroring the duplication done in `_pad_axis`.
        i1 = idx + (slice(None, -(Nyquist+1), -1),)
        xnew[i1] += x[i1]
        return xnew