field_zero_padder.py 2.42 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
from __future__ import absolute_import, division, print_function

import numpy as np

from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
8
9
from ..domains.rg_space import RGSpace
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
10
from .linear_operator import LinearOperator
11
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
12
13
14


class FieldZeroPadder(LinearOperator):
    """Linear operator that zero-pads a field along one RGSpace sub-domain.

    In ``TIMES`` mode the input values are copied into the leading corner
    (indices ``0 .. old_size - 1`` along every padded axis) of a
    zero-initialised array with the enlarged shape.  In ``ADJOINT_TIMES``
    mode the enlarged array is cropped back to that same leading corner.

    Parameters
    ----------
    domain : anything accepted by ``DomainTuple.make``
        The domain of the operator.
    new_shape : sequence of int
        Shape of the padded sub-domain.  It must have as many entries as the
        shape of the selected sub-domain, and no entry may be smaller than
        the corresponding original extent (equal extents are accepted).
    space : optional
        Selects the sub-domain to be padded; it is resolved through
        ``utilities.infer_space``.  Defaults to 0.

    Raises
    ------
    TypeError
        If the selected sub-domain is not an ``RGSpace`` or is harmonic.
    ValueError
        If ``new_shape`` has the wrong number of entries, or if any entry is
        smaller than the corresponding entry of the original shape.
    """

    def __init__(self, domain, new_shape, space=0):
        super(FieldZeroPadder, self).__init__()
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)

        dom = self._domain[self._space]
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if dom.harmonic:
            raise TypeError("RGSpace must not be harmonic")

        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        # Only strictly smaller extents are rejected; equal ones pass.
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # The target equals the domain, except that the padded sub-domain is
        # replaced by an RGSpace with the new shape and the old distances.
        tgt = RGSpace(new_shape, dom.distances)
        self._target = list(self._domain)
        self._target[self._space] = tgt
        self._target = DomainTuple.make(self._target)

    @property
    def domain(self):
        """The DomainTuple built from the ``domain`` constructor argument."""
        return self._domain

    @property
    def target(self):
        """The domain with the padded sub-domain replaced by the new one."""
        return self._target

    @property
    def capability(self):
        """Supported modes: ``TIMES`` and ``ADJOINT_TIMES``."""
        return self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Pad (any mode other than ``ADJOINT_TIMES``) or crop
        (``ADJOINT_TIMES``) the field `x`.

        Parameters
        ----------
        x : Field
            The input field; it is validated by ``self._check_input``.
        mode : int
            The requested application mode.

        Returns
        -------
        Field
            A field living on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        x = x.val
        dax = dobj.distaxis(x)
        shp_in = x.shape
        shp_out = self._tgt(mode).shape

        # Array axes belonging to the padded sub-domain; every axis in front
        # of the first one is copied unchanged.
        axes = self._target.axes[self._space]
        axbefore = axes[0]
        leading = (slice(None),)*axbefore

        # NOTE(review): presumably dobj.distaxis reports the axis along which
        # the array is distributed; if that axis is one of the padded axes,
        # the data is redistributed first so the local slicing below only
        # touches axes that are not distributed -- confirm against dobj.
        needs_redist = dax in axes
        if needs_redist:
            x = dobj.redistribute(x, nodist=axes)
        curax = dobj.distaxis(x)
        local_shape = dobj.local_shape(shp_out, curax)

        if mode == self.ADJOINT_TIMES:
            # Crop: keep the leading shp_out[axis] entries of each padded
            # axis.  Every element of newarr is overwritten by the
            # assignment, so no zero-initialisation is needed.
            newarr = np.empty(local_shape, dtype=x.dtype)
            sl = tuple(slice(0, shp_out[axis]) for axis in axes)
            newarr[()] = dobj.local_data(x)[leading + sl]
        else:
            # Pad: start from zeros and place the input in the leading
            # corner of each padded axis.
            newarr = np.zeros(local_shape, dtype=x.dtype)
            sl = tuple(slice(0, shp_in[axis]) for axis in axes)
            newarr[leading + sl] = dobj.local_data(x)

        newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
        if needs_redist:
            # Redistribute along the axis reported for the original input.
            newarr = dobj.redistribute(newarr, dist=dax)
        return Field(self._tgt(mode), val=newarr)