# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
20
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domains.rg_space import RGSpace
Martin Reinecke's avatar
Martin Reinecke committed
22
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
23
24
25
from .. import dobj
from .. import utilities
from ..field import Field
Jait Dixit's avatar
Jait Dixit committed
26
27


Jait Dixit's avatar
Jait Dixit committed
28
class FFTOperator(LinearOperator):
    """Transforms between a pair of position and harmonic RGSpaces.

    The transform is carried out as a Hartley transform on real-valued data;
    complex-valued input is handled by transforming its real and imaginary
    parts separately and recombining the results.

    Parameters
    ----------
    domain: Space, tuple of Spaces or DomainObject
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    target: Space
        The target space of the transform operation.
        If omitted, a space will be chosen automatically.
    space: the index of the space on which the operator should act
        If None, it is set to 0 if domain contains exactly one space.
        domain[space] must be an RGSpace.

    Raises
    ------
    TypeError
        If domain[space] is not an RGSpace.
    """

    def __init__(self, domain, target=None, space=None):
        super(FFTOperator, self).__init__()

        # Initialize domain and target
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)

        adom = self._domain[self._space]
        if not isinstance(adom, RGSpace):
            raise TypeError("FFTOperator only works on RGSpaces")
        if target is None:
            target = adom.get_default_codomain()

        # The target is a copy of the domain in which only the transformed
        # sub-space is exchanged for its codomain.
        target_spaces = list(self._domain)
        target_spaces[self._space] = target
        self._target = DomainTuple.make(target_spaces)
        # Each of the two spaces has to accept the other one as its codomain.
        adom.check_codomain(target)
        target.check_codomain(adom)

        # pyfftw is imported here rather than at module level; its interface
        # cache is switched on as soon as an operator is constructed.
        import pyfftw
        pyfftw.interfaces.cache.enable()

    def apply(self, x, mode):
        """Apply the operator to the field `x` in the requested `mode`.

        Parameters
        ----------
        x: Field
            The input field; it is validated with `_check_input`.
        mode: int
            One of the operator mode flags defined by LinearOperator.

        Returns
        -------
        Field
            The transformed field.
        """
        self._check_input(x, mode)
        if np.issubdtype(x.dtype, np.complexfloating):
            # Transform real and imaginary part independently and recombine.
            return (self._apply_cartesian(x.real, mode) +
                    1j*self._apply_cartesian(x.imag, mode))
        return self._apply_cartesian(x, mode)

    def _apply_cartesian(self, x, mode):
        """Hartley-transform `x` along the operator's axes and rescale it.

        Which of three strategies is used depends on whether the axis over
        which `x.val` is distributed is among the axes to be transformed.

        Parameters
        ----------
        x: Field
            The input field (presumably real-valued, since `apply` splits
            complex input before calling this method).
        mode: int
            The operator mode; it selects the volume factor that is applied
            to the result.

        Returns
        -------
        Field
            The transformed field. It lives on the operator's target if `x`
            lives on its domain, and on the domain otherwise.

        Raises
        ------
        ValueError
            If all axes of `x` have to be transformed and the data is not
            distributed along axis 0.
        """
        axes = x.domain.axes[self._space]
        tdom = self._target if x.domain == self._domain else self._domain
        oldax = dobj.distaxis(x.val)
        if oldax not in axes:  # straightforward, no redistribution needed
            ldat = dobj.local_data(x.val)
            ldat = utilities.hartley(ldat, axes=axes)
            tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
        elif len(axes) < len(x.shape) or len(axes) == 1:
            # we can use one Hartley pass in between the redistributions
            tmp = dobj.redistribute(x.val, nodist=axes)
            newax = dobj.distaxis(tmp)
            ldat = dobj.local_data(tmp)
            ldat = utilities.hartley(ldat, axes=axes)
            tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
            tmp = dobj.redistribute(tmp, dist=oldax)
        else:  # two separate, full FFTs needed
            # ideal strategy for the moment would be:
            # - do real-to-complex FFT on all local axes
            # - fill up array
            # - redistribute array
            # - do complex-to-complex FFT on remaining axis
            # - add re+im
            # - redistribute back
            if oldax != 0:
                # Checked up front so that no FFT work is wasted on a
                # distribution this branch cannot handle.
                raise ValueError("bad distribution")
            # Only this branch needs a plain complex FFT.
            from pyfftw.interfaces.numpy_fft import fftn
            rem_axes = tuple(i for i in axes if i != oldax)
            ldat = dobj.local_data(x.val)
            ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
            # Collapse all trailing axes so the data can be treated as a 2D
            # array that is distributed along its first axis.
            ldat2 = ldat.reshape((ldat.shape[0],
                                  np.prod(ldat.shape[1:])))
            shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
            tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
            tmp = dobj.transpose(tmp)
            ldat2 = dobj.local_data(tmp)
            # Transform along the axis that was distributed before the
            # transpose, then add real and imaginary parts.
            ldat2 = fftn(ldat2, axes=(1,))
            ldat2 = ldat2.real+ldat2.imag
            tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
            # Undo the transpose and restore the original array shape.
            tmp = dobj.transpose(tmp)
            ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
            tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
        Tval = Field(tdom, tmp)
        # The forward and adjoint modes are scaled with the volume factor of
        # the domain, all other modes with that of the target.
        if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
            fct = self._domain[self._space].scalar_dvol()
        else:
            fct = self._target[self._space].scalar_dvol()
        if fct != 1:
            Tval *= fct

        return Tval

    @property
    def domain(self):
        """DomainTuple: the domain built from the `domain` argument."""
        return self._domain

    @property
    def target(self):
        """DomainTuple: the domain with the transformed space replaced."""
        return self._target

    @property
    def capability(self):
        """The supported operator modes (the inherited `_all_ops` value)."""
        return self._all_ops