central_zero_padder.py 2.78 KB
Newer Older
Julia Stadler's avatar
Julia Stadler committed
1
2
3
4
5
6
7
8
import numpy as np
import itertools

from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
9
from .. import dobj
Julia Stadler's avatar
Julia Stadler committed
10
11


Philipp Arras's avatar
Philipp Arras committed
12
class CentralZeroPadder(LinearOperator):
    """Zero-pads one RGSpace of a domain to a larger shape, padding in the
    middle of every axis.

    Along each axis of the chosen space with length ``n``, the first
    ``(n+1)//2`` entries are copied to the start of the new axis. The last
    ``n//2`` entries are copied to the end of the new axis. All entries in
    between are set to zero. The adjoint extracts exactly those entries again.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The domain of the operator (passed to ``DomainTuple.make``).
    new_shape : sequence of int
        Shape of the padded space. It must have the same length as the shape
        of the selected space, and no entry may be smaller than the
        corresponding original one.
    space : int
        Index of the sub-domain to pad (resolved via
        ``utilities.infer_space``). Default: 0.

    Raises
    ------
    TypeError
        If the selected sub-domain is not an RGSpace, or is harmonic.
    ValueError
        If ``new_shape`` has the wrong length or is smaller than the
        original shape along any axis.
    """

    def __init__(self, domain, new_shape, space=0):
        super(CentralZeroPadder, self).__init__()

        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]

        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if dom.harmonic:
            raise TypeError("RGSpace must not be harmonic")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # Target: identical to the domain except that the selected space is
        # replaced by a larger RGSpace with the same pixel distances.
        tgt = RGSpace(new_shape, dom.distances)
        self._target = list(self._domain)
        self._target[self._space] = tgt
        self._target = DomainTuple.make(self._target)

        # For every padded axis there are two index ranges: the leading
        # (n+1)//2 entries and the trailing n//2 entries (walked backwards
        # from the last element, which addresses the end of an axis of any
        # length). Axes outside the padded space are taken whole. The
        # Cartesian product enumerates every combination of leading/trailing
        # ranges, i.e. every "corner" block that is copied.
        axes = self._target.axes[self._space]
        per_axis = []
        for i, n in enumerate(self._domain.shape):
            if i in axes:
                per_axis.append((slice(0, (n+1)//2),
                                 slice(-1, -1-(n//2), -1)))
            else:
                per_axis.append((slice(None),))
        # Stored as tuples: NumPy >= 1.23 no longer accepts a list of slices
        # as a multidimensional index.
        self.slicer = [tuple(s) for s in itertools.product(*per_axis)]

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def capability(self):
        return self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Pad (``TIMES``) or extract (``ADJOINT_TIMES``) the field ``x``.

        In ``TIMES`` mode, the output starts as zeros and the corner blocks of
        ``x`` are copied into the corresponding corners of the larger array.
        In ``ADJOINT_TIMES`` mode, the same corner blocks are read from ``x``
        into an uninitialised array. The blocks cover every entry of the
        smaller array, so nothing is left uninitialised.
        """
        self._check_input(x, mode)
        x = x.val
        dax = dobj.distaxis(x)
        shp_out = self._tgt(mode).shape
        axes = self._target.axes[self._space]
        # Redistribute so that none of the padded axes is the distributed
        # one; the slicing below then acts on complete local axes.
        if dax in axes:
            x = dobj.redistribute(x, nodist=axes)
        curax = dobj.distaxis(x)
        x = dobj.local_data(x)

        if mode == self.TIMES:
            y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
        else:
            y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
        for sl in self.slicer:
            y[sl] = x[sl]

        y = dobj.from_local_data(shp_out, y, distaxis=curax)
        # Restore the original distribution of the input.
        if dax in axes:
            y = dobj.redistribute(y, dist=dax)
        return Field(self._tgt(mode), val=y)