field_zero_padder.py 2.73 KB
Newer Older
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
21
22
from __future__ import absolute_import, division, print_function

import numpy as np

Philipp Arras's avatar
Philipp Arras committed
23
from .. import dobj, utilities
Martin Reinecke's avatar
Martin Reinecke committed
24
25
from ..compat import *
from ..domain_tuple import DomainTuple
26
27
from ..domains.rg_space import RGSpace
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
28
29
30
31
from .linear_operator import LinearOperator


class FieldZeroPadder(LinearOperator):
    """Linear operator that zero-pads one RGSpace sub-domain of a field.

    In ``TIMES`` mode the input is copied, axis by axis, into the leading
    part (indices ``0 .. old_extent - 1``) of a zero-initialised array whose
    extent along each axis of the selected space is taken from the target.
    In ``ADJOINT_TIMES`` mode the leading part of each of those axes is
    sliced out again, discarding the padded region.

    Parameters
    ----------
    domain : anything accepted by ``DomainTuple.make``
        The domain of the operator.
    new_shape : sequence of int
        Shape of the padded space. Must have exactly one entry per axis of
        the selected sub-domain, and no entry may be smaller than the
        corresponding entry of that sub-domain's shape.
    space : int, optional
        Selector of the sub-domain to pad; it is resolved via
        ``utilities.infer_space``. Defaults to 0.

    Raises
    ------
    TypeError
        If the selected sub-domain is not an ``RGSpace`` or is harmonic.
    ValueError
        If ``new_shape`` has the wrong number of entries, or if any entry is
        smaller than the corresponding extent of the selected sub-domain.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if dom.harmonic:
            raise TypeError("RGSpace must not be harmonic")

        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        # Equal extents are accepted (no padding along that axis); only a
        # shrinking axis is rejected.
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must not be smaller than old shape")

        # The target is the domain with the selected space replaced by an
        # RGSpace of the new shape that reuses the original distances.
        # Build it in a local so self._target is only ever a DomainTuple.
        target = list(self._domain)
        target[self._space] = RGSpace(new_shape, dom.distances)
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the padding (``TIMES``) or its adjoint (``ADJOINT_TIMES``).

        Parameters
        ----------
        x : Field
            Input field; it is validated with ``self._check_input(x, mode)``.
        mode : int
            ``self.TIMES`` pads with zeros; any other mode that passes the
            input check is handled as ``ADJOINT_TIMES`` and crops.

        Returns
        -------
        Field
            A field living on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        v = x.val
        # Running global shape; updated one axis at a time as the data is
        # resized.
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        for d in self._target.axes[self._space]:
            # Full slices for all axes preceding d, so that the slice
            # appended below acts on axis d only.
            idx = (slice(None),) * d

            # NOTE(review): presumably returns the (possibly redistributed)
            # data object together with its local array, arranged so that
            # axis d is not the distributed axis -- confirm against dobj.
            v, local = dobj.ensure_not_distributed(v, (d,))

            if mode == self.TIMES:
                shp = list(local.shape)
                shp[d] = tgtshp[d]
                xnew = np.zeros(shp, dtype=local.dtype)
                # Copy the data into the leading part of axis d; the rest
                # stays zero.
                xnew[idx + (slice(0, local.shape[d]),)] = local
            else:  # ADJOINT_TIMES
                # Keep only the leading part of axis d.
                xnew = local[idx + (slice(0, tgtshp[d]),)]

            curshp[d] = xnew.shape[d]
            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))