fft_operator.py 5.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
20
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domains.rg_space import RGSpace
Martin Reinecke's avatar
Martin Reinecke committed
22
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
23
24
25
from .. import dobj
from .. import utilities
from ..field import Field
Jait Dixit's avatar
Jait Dixit committed
26
27


Jait Dixit's avatar
Jait Dixit committed
28
class FFTOperator(LinearOperator):
    """Transforms between a pair of position and harmonic RGSpaces.

    Real-valued input is transformed with `utilities.hartley` (or an
    equivalent two-stage FFT when the data distribution requires it);
    complex-valued input is handled by transforming its real and imaginary
    parts separately.

    Parameters
    ----------
    domain: Domain, tuple of Domain or DomainTuple
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    target: Domain, optional
        The target (sub-)domain of the transform operation.
        If omitted, the default codomain of `domain[space]` is used.
    space: int, optional
        The index of the subdomain on which the operator should act.
        If None, it is set to 0 if `domain` contains exactly one space.
        `domain[space]` must be an RGSpace.

    Raises
    ------
    TypeError
        If `domain[space]` is not an RGSpace.
    """

    def __init__(self, domain, target=None, space=None):
        super(FFTOperator, self).__init__()

        # Normalise the domain and find the sub-domain to be transformed.
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)

        source_space = self._domain[self._space]
        if not isinstance(source_space, RGSpace):
            raise TypeError("FFTOperator only works on RGSpaces")
        if target is None:
            target = source_space.get_default_codomain()

        # The target is the domain with only the transformed entry replaced.
        target_spaces = list(self._domain)
        target_spaces[self._space] = target
        self._target = DomainTuple.make(target_spaces)

        # Both spaces must accept each other as codomain.
        source_space.check_codomain(target)
        target.check_codomain(source_space)

        # Imported here so that pyfftw is only needed once an operator is
        # actually instantiated.
        import pyfftw
        pyfftw.interfaces.cache.enable()
        pyfftw.interfaces.cache.set_keepalive_time(1000.)

    def apply(self, x, mode):
        """Apply the operator to the Field `x` in the requested `mode`.

        Complex-valued input is split into real and imaginary part, each of
        which is transformed on its own and then recombined.
        """
        self._check_input(x, mode)
        if not np.issubdtype(x.dtype, np.complexfloating):
            return self._apply_cartesian(x, mode)

        real_part = self._apply_cartesian(x.real, mode)
        imag_part = self._apply_cartesian(x.imag, mode)
        return real_part + 1j*imag_part

    def _apply_cartesian(self, x, mode):
        """Transform `x` along the axes of `self._space` and rescale it.

        `x` is presumably real-valued here, since `apply` splits complex
        input before calling this method.  The strategy depends on whether
        the distributed axis of `x.val` is among the transformed axes.
        """
        from pyfftw.interfaces.numpy_fft import fftn

        fft_axes = x.domain.axes[self._space]
        # The output lives on whichever of domain/target the input is not on.
        if x.domain == self._domain:
            out_domain = self._target
        else:
            out_domain = self._domain
        dist_axis = dobj.distaxis(x.val)

        if dist_axis not in fft_axes:
            # straightforward, no redistribution needed
            local = utilities.hartley(x.local_data, axes=fft_axes)
            result = dobj.from_local_data(x.val.shape, local,
                                          distaxis=dist_axis)
        elif len(fft_axes) < len(x.shape) or len(fft_axes) == 1:
            # we can use one Hartley pass in between the redistributions
            moved = dobj.redistribute(x.val, nodist=fft_axes)
            moved_axis = dobj.distaxis(moved)
            local = utilities.hartley(dobj.local_data(moved), axes=fft_axes)
            moved = dobj.from_local_data(moved.shape, local,
                                         distaxis=moved_axis)
            result = dobj.redistribute(moved, dist=dist_axis)
        else:
            # two separate, full FFTs needed
            # ideal strategy for the moment would be:
            # - do real-to-complex FFT on all local axes
            # - fill up array
            # - redistribute array
            # - do complex-to-complex FFT on remaining axis
            # - add re+im
            # - redistribute back
            other_axes = tuple(ax for ax in fft_axes if ax != dist_axis)
            spectrum = utilities.my_fftn_r2c(dobj.local_data(x.val),
                                             axes=other_axes)
            if dist_axis != 0:
                raise ValueError("bad distribution")

            # Collapse all trailing axes so the array can be handled as 2D.
            flat_local = spectrum.reshape(
                (spectrum.shape[0], np.prod(spectrum.shape[1:])))
            flat_shape = (x.val.shape[0], np.prod(x.val.shape[1:]))
            flat = dobj.from_local_data(flat_shape, flat_local, distaxis=0)
            flat = dobj.transpose(flat)

            # Complex FFT along the remaining axis, then add re + im.
            column_fft = fftn(dobj.local_data(flat), axes=(1,))
            combined = column_fft.real+column_fft.imag

            # Undo the transpose and restore the original local shape.
            flat = dobj.from_local_data(flat.shape, combined, distaxis=0)
            flat = dobj.transpose(flat)
            local = dobj.local_data(flat).reshape(spectrum.shape)
            result = dobj.from_local_data(x.val.shape, local, distaxis=0)

        transformed = Field(out_domain, result)

        # TIMES and ADJOINT_TIMES are scaled with the volume factor of the
        # domain space, all other modes with that of the target space.
        forward_modes = LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES
        if mode & forward_modes:
            scale = self._domain[self._space].scalar_dvol
        else:
            scale = self._target[self._space].scalar_dvol
        if scale != 1:
            transformed *= scale

        return transformed

    @property
    def domain(self):
        """The DomainTuple built from the `domain` constructor argument."""
        return self._domain

    @property
    def target(self):
        """The domain with the entry at `space` replaced by the codomain."""
        return self._target

    @property
    def capability(self):
        """The supported operation modes (`self._all_ops`)."""
        return self._all_ops