field_zero_padder.py 3.96 KB
Newer Older
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
21
22
from __future__ import absolute_import, division, print_function

import numpy as np

Philipp Arras's avatar
Philipp Arras committed
23
from .. import dobj, utilities
Martin Reinecke's avatar
Martin Reinecke committed
24
25
from ..compat import *
from ..domain_tuple import DomainTuple
26
27
from ..domains.rg_space import RGSpace
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
28
29
30
31
from .linear_operator import LinearOperator


class FieldZeroPadder(LinearOperator):
    """Zero-pads a field along the axes of one RGSpace sub-domain.

    The selected sub-domain must be a non-harmonic `RGSpace`. The target
    domain equals the input domain, except that this sub-domain is replaced
    by an `RGSpace` of shape `new_shape` with the same pixel distances.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The operator's input domain (anything accepted by
        `DomainTuple.make`).
    new_shape : tuple of int
        Shape of the padded sub-domain. Must have as many entries as the
        original sub-domain shape, and no entry may be smaller than the
        corresponding original extent. Axes whose extent is unchanged are
        passed through untouched.
    space : int
        Index of the sub-domain to be padded (default: 0).
    central : bool
        If False (default), the input occupies the leading entries along
        each padded axis and the remainder is filled with zeros.
        If True, for a padded axis of original length n the first n//2+1
        entries are placed at the start and the last n//2 entries at the
        end of the new axis, with zeros in between. For even n the entry at
        index n//2 is therefore copied, unchanged, to both positions; the
        adjoint adds both copies back together.

    Raises
    ------
    TypeError
        If the selected sub-domain is not an `RGSpace`, or is harmonic.
    ValueError
        If `new_shape` has the wrong length or is smaller than the original
        shape in any dimension.
    """

    def __init__(self, domain, new_shape, space=0, central=False):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        self._central = central

        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if dom.harmonic:
            raise TypeError("RGSpace must not be harmonic")

        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        # Equal extents are allowed: `apply` skips such axes entirely, which
        # keeps TIMES and ADJOINT_TIMES exact adjoints of each other.
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be smaller than old shape")

        tgt = list(self._domain)
        # Same pixel distances, more pixels: the padded space is larger.
        tgt[self._space] = RGSpace(new_shape, dom.distances)
        self._target = DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    @staticmethod
    def _pad_axis(arr, axis, new_len, central):
        """Return a new array: `arr` zero-padded to `new_len` along `axis`."""
        idx = (slice(None),) * axis
        shp = list(arr.shape)
        shp[axis] = new_len
        res = np.zeros(shp, dtype=arr.dtype)
        if central:
            nyq = arr.shape[axis]//2
            # Leading n//2+1 entries go to the start of the padded axis ...
            sl = idx + (slice(0, nyq+1),)
            res[sl] = arr[sl]
            # ... and the trailing n//2 entries to its end. Walking backwards
            # from the last element lets one slice address the tail of both
            # arrays even though their lengths differ.
            sl = idx + (slice(None, -(nyq+1), -1),)
            res[sl] = arr[sl]
        else:
            res[idx + (slice(0, arr.shape[axis]),)] = arr
        return res

    @staticmethod
    def _unpad_axis(arr, axis, new_len, central):
        """Adjoint of `_pad_axis`: shrink `arr` to `new_len` along `axis`.

        In the non-central case the result is a slice (view) of `arr`.
        """
        idx = (slice(None),) * axis
        if not central:
            return arr[idx + (slice(0, new_len),)]
        shp = list(arr.shape)
        shp[axis] = new_len
        res = np.zeros(shp, dtype=arr.dtype)
        nyq = new_len//2
        sl = idx + (slice(0, nyq+1),)
        res[sl] = arr[sl]
        # `+=` so that, for even new_len, both copies of the entry at index
        # new_len//2 made by `_pad_axis` are summed back together.
        sl = idx + (slice(None, -(nyq+1), -1),)
        res[sl] += arr[sl]
        return res

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        for d in self._target.axes[self._space]:
            if curshp[d] == tgtshp[d]:
                # Axis not padded: identity in both directions.
                continue
            # The second return value is used below as the local array that
            # is padded/truncated along axis d.
            v, arr = dobj.ensure_not_distributed(v, (d,))
            if mode == self.TIMES:
                arr = self._pad_axis(arr, d, tgtshp[d], self._central)
            else:  # ADJOINT_TIMES
                arr = self._unpad_axis(arr, d, tgtshp[d], self._central)
            curshp[d] = arr.shape[d]
            v = dobj.from_local_data(curshp, arr, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))