central_zero_padder.py 2.97 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
from __future__ import absolute_import, division, print_function

Julia Stadler's avatar
Julia Stadler committed
3
4
5
import numpy as np
import itertools

Martin Reinecke's avatar
Martin Reinecke committed
6
from ..compat import *
Julia Stadler's avatar
Julia Stadler committed
7
8
9
10
11
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
12
from .. import dobj
Julia Stadler's avatar
Julia Stadler committed
13
14


Martin Reinecke's avatar
Martin Reinecke committed
15
16
# MR FIXME: for even axis lengths, we probably should split the value at the
#           highest frequency.
Philipp Arras's avatar
Philipp Arras committed
17
class CentralZeroPadder(LinearOperator):
    """Operator that enlarges one RGSpace of a domain by inserting zeros
    in the middle of every axis of that space.

    For each axis of the selected sub-space, the leading half of the input
    (rounded up) is copied to the beginning of the output axis and the
    trailing half (rounded down) is copied to the end of the output axis.
    Everything in between is filled with zeros.  The adjoint operation
    extracts the same leading/trailing parts from the large array again.

    Parameters
    ----------
    domain : Domain, tuple of Domains or DomainTuple
        The domain of the operator (anything accepted by
        `DomainTuple.make`).
    new_shape : sequence of int
        Shape of the enlarged RGSpace.  It must have as many entries as the
        RGSpace being padded, and no entry may be smaller than the
        corresponding entry of the original shape.
    space : int, optional
        Index of the sub-domain that is padded (resolved through
        `utilities.infer_space`).  Default: 0.

    Raises
    ------
    TypeError
        If the selected sub-domain is not an RGSpace or is harmonic.
    ValueError
        If `new_shape` has the wrong number of dimensions or is smaller
        than the original shape along any axis.
    """

    def __init__(self, domain, new_shape, space=0):
        super(CentralZeroPadder, self).__init__()

        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]

        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if dom.harmonic:
            raise TypeError("RGSpace must not be harmonic")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # The target equals the domain, except that the padded sub-domain is
        # replaced by an RGSpace with the new shape and the same distances.
        tgt = RGSpace(new_shape, dom.distances)
        self._target = list(self._domain)
        self._target[self._space] = tgt
        self._target = DomainTuple.make(self._target)

        axes = self._target.axes[self._space]
        ndim = len(self._domain.shape)
        padded_axes = [i for i in range(ndim) if i in axes]

        # For every padded axis build two slices: one selecting the leading
        # (n+1)//2 entries and one selecting the trailing n//2 entries.  The
        # trailing slice counts from the end with a negative step, so the
        # very same slice object addresses the matching region in both the
        # small and the enlarged array.
        halves = []
        for i in padded_axes:
            n = self._domain.shape[i]
            slicer_fw = slice(0, (n+1)//2)
            slicer_bw = slice(-1, -1-(n//2), -1)
            halves.append([slicer_fw, slicer_bw])

        # Every combination of leading/trailing halves addresses one corner
        # block.  Axes that are not padded are taken in full.  Each index is
        # stored as a *tuple*: indexing an ndarray with a list of slices is
        # not supported by current NumPy versions.
        self.slicer = []
        for combo in itertools.product(*halves):
            index = [slice(None)]*ndim
            for ax, slc in zip(padded_axes, combo):
                index[ax] = slc
            self.slicer.append(tuple(index))

    @property
    def domain(self):
        """DomainTuple : the domain on which the operator acts."""
        return self._domain

    @property
    def target(self):
        """DomainTuple : the domain with the selected RGSpace enlarged."""
        return self._target

    @property
    def capability(self):
        """int : supported modes, i.e. `TIMES | ADJOINT_TIMES`."""
        return self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the padding (`TIMES`) or its adjoint (any other mode that
        passes `_check_input`) to the Field `x` and return a new Field."""
        self._check_input(x, mode)
        x = x.val
        dax = dobj.distaxis(x)
        shp_out = self._tgt(mode).shape
        axes = self._target.axes[self._space]
        # The slices need to address the padded axes directly, so the data
        # is redistributed if it is currently distributed along one of them.
        if dax in axes:
            x = dobj.redistribute(x, nodist=axes)
        curax = dobj.distaxis(x)
        x = dobj.local_data(x)

        local_shape = dobj.local_shape(shp_out, curax)
        if mode == self.TIMES:
            # The region between the copied corner blocks has to be zero.
            y = np.zeros(local_shape, dtype=x.dtype)
        else:
            # Along every padded axis the leading (n+1)//2 and trailing n//2
            # entries add up to n, so the loop below writes every entry of
            # the output (assuming the padded axes are held in full locally
            # after the redistribution above); no initialisation is needed.
            y = np.empty(local_shape, dtype=x.dtype)
        for i in self.slicer:
            y[i] = x[i]

        y = dobj.from_local_data(shp_out, y, distaxis=curax)
        # Restore the distribution axis the input originally had.
        if dax in axes:
            y = dobj.redistribute(y, dist=dax)
        return Field(self._tgt(mode), val=y)