# fft_operator_support.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import numpy as np
from .. import utilities
from .. import dobj
from ..field import Field
from ..spaces.gl_space import GLSpace
from .linear_operator import LinearOperator


class Transformation(object):
    """Common state shared by the concrete harmonic transformations.

    Parameters
    ----------
    hdom :
        Domain holding the harmonic-space representation (presumably a
        DomainTuple; subclasses index it with `space`).
    pdom :
        Domain holding the position-space representation (indexed the
        same way as `hdom`).
    space : int
        Index of the sub-space within `hdom`/`pdom` that is transformed.
    """

    def __init__(self, hdom, pdom, space):
        self.hdom, self.pdom, self.space = hdom, pdom, space


class RGRGTransformation(Transformation):
    """Hartley-type transformation between two regular-grid (RG) domains.

    The raw transform is computed over all axes of sub-space `space` and
    the result is scaled afterwards by `fct_noninverse` or `fct_inverse`,
    depending on the operator mode. Requires pyFFTW.
    """

    def __init__(self, hdom, pdom, space):
        import pyfftw
        super(RGRGTransformation, self).__init__(hdom, pdom, space)
        # Enable pyFFTW's interface cache so FFTW objects are reused
        # across calls instead of being re-planned every time.
        pyfftw.interfaces.cache.enable()
        # Scaling factors applied after the raw transform:
        # TIMES/ADJOINT_TIMES use the harmonic volume element, all other
        # modes use 1/(dvol*dim).
        self.fct_noninverse = hdom[space].scalar_dvol()
        self.fct_inverse = 1./(hdom[space].scalar_dvol()*hdom[space].dim)

    @property
    def unitary(self):
        return True

    def apply(self, x, mode):
        """RG -> RG transform method.

        Parameters
        ----------
        x : Field
            The field to be transformed. If its domain equals `self.pdom`
            the result lives on `self.hdom`, otherwise on `self.pdom`.
        mode :
            One of the LinearOperator mode constants. TIMES and
            ADJOINT_TIMES scale the result by `fct_noninverse`; every other
            mode scales it by `fct_inverse`.

        Returns
        -------
        Field
            The transformed and scaled field.

        Raises
        ------
        ValueError
            If all axes are transformed and the data is distributed along
            an axis other than 0 (not supported yet).
        """
        axes = x.domain.axes[self.space]
        p2h = x.domain == self.pdom
        tdom = self.hdom if p2h else self.pdom
        oldax = dobj.distaxis(x.val)

        if oldax not in axes:
            # The distributed axis is not transformed, so every task can
            # transform its local data independently.
            ldat = dobj.local_data(x.val)
            ldat = utilities.hartley(ldat, axes=axes)
            tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
        elif len(axes) < len(x.shape) or len(axes) == 1:
            # Some axis is left over to carry the distribution, so a single
            # Hartley pass between two redistributions suffices.
            tmp = self._hartley_redistributed(x.val, axes, oldax)
        else:
            # Every axis is transformed, including the distributed one.
            tmp = self._hartley_all_axes(x.val, axes, oldax)

        res = Field(tdom, tmp)
        if mode in (LinearOperator.TIMES, LinearOperator.ADJOINT_TIMES):
            fct = self.fct_noninverse
        else:
            fct = self.fct_inverse
        if fct != 1:
            res *= fct
        return res

    @staticmethod
    def _hartley_redistributed(val, axes, oldax):
        """Make `axes` local, Hartley-transform them, redistribute back."""
        tmp = dobj.redistribute(val, nodist=axes)
        newax = dobj.distaxis(tmp)
        ldat = dobj.local_data(tmp)
        ldat = utilities.hartley(ldat, axes=axes)
        tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
        return dobj.redistribute(tmp, dist=oldax)

    @staticmethod
    def _hartley_all_axes(val, axes, oldax):
        """Transform all axes of `val`, which is distributed along `oldax`.

        Strategy: real-to-complex FFT over the locally complete axes, then
        a transpose so the distributed axis becomes local, a complex FFT
        along it, Re+Im to obtain the real result, and a transpose back.
        Only distribution along axis 0 is currently supported.
        """
        from pyfftw.interfaces.numpy_fft import fftn
        if oldax != 0:
            raise ValueError("bad distribution")
        rem_axes = tuple(i for i in axes if i != oldax)

        # Step 1: real-to-complex FFT on all axes that are fully local.
        ldat = dobj.local_data(val)
        ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)

        # Step 2: view the data as 2D (axis 0 x everything else) and
        # transpose; afterwards the original axis 0 is presumably local
        # along axis 1, which is why the FFT below uses axes=(1,).
        ldat2 = ldat.reshape((ldat.shape[0], np.prod(ldat.shape[1:])))
        shp2d = (val.shape[0], np.prod(val.shape[1:]))
        tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
        tmp = dobj.transpose(tmp)

        # Step 3: complex FFT along the now-local axis; Re+Im collapses the
        # complex result into a real, Hartley-type one.
        ldat2 = dobj.local_data(tmp)
        ldat2 = fftn(ldat2, axes=(1,))
        ldat2 = ldat2.real + ldat2.imag

        # Step 4: transpose back and restore the original shape.
        tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
        tmp = dobj.transpose(tmp)
        ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
        return dobj.from_local_data(val.shape, ldat2, distaxis=0)


class SphericalTransformation(Transformation):
    """Spherical harmonic transformation via ``pyHealpix.sharpjob_d``.

    The position sub-space is either a GLSpace (Gauss-Legendre geometry,
    `nlat`/`nlon`) or is treated as a HEALPix space (uses `nside`). The
    harmonic sub-space provides `lmax`/`mmax`. Harmonic coefficients are
    stored as a real array; see `_slice_p2h`/`_slice_h2p` for the packing.
    """

    def __init__(self, hdom, pdom, space):
        super(SphericalTransformation, self).__init__(hdom, pdom, space)
        from pyHealpix import sharpjob_d

        self.lmax = self.hdom[self.space].lmax
        self.mmax = self.hdom[self.space].mmax

        self.sjob = sharpjob_d()
        self.sjob.set_triangular_alm_info(self.lmax, self.mmax)
        if isinstance(self.pdom[self.space], GLSpace):
            self.sjob.set_Gauss_geometry(self.pdom[self.space].nlat,
                                         self.pdom[self.space].nlon)
        else:
            self.sjob.set_Healpix_geometry(self.pdom[self.space].nside)

    @property
    def unitary(self):
        return False

    def _nalm(self):
        """Number of complex coefficients for the triangular (lmax, mmax)
        layout; used to sanity-check array lengths."""
        return (((self.mmax+1)*(self.mmax+2))//2 +
                (self.mmax+1)*(self.lmax-self.mmax))

    def _slice_p2h(self, inp):
        """Map one position-space slice to packed real harmonic coefficients.

        Runs the adjoint of ``alm2map`` on `inp`. The first lmax+1 complex
        coefficients (presumably the m=0 ones) keep only their real part;
        every later coefficient c is stored as the pair
        (sqrt(2)*Re c, sqrt(2)*Im c). The result is divided by sqrt(4*pi).
        """
        rr = self.sjob.alm2map_adjoint(inp)
        assert len(rr) == self._nalm()
        res = np.empty(2*len(rr)-self.lmax-1, dtype=rr[0].real.dtype)
        res[0:self.lmax+1] = rr[0:self.lmax+1].real
        res[self.lmax+1::2] = np.sqrt(2)*rr[self.lmax+1:].real
        res[self.lmax+2::2] = np.sqrt(2)*rr[self.lmax+1:].imag
        return res/np.sqrt(np.pi*4)

    def _slice_h2p(self, inp):
        """Map one slice of packed real coefficients to position space.

        Inverts the packing of `_slice_p2h` (the sqrt(2) factor becomes
        sqrt(0.5)), runs ``alm2map`` and divides by sqrt(4*pi).
        """
        res = np.empty((len(inp)+self.lmax+1)//2, dtype=(inp[0]*1j).dtype)
        assert len(res) == self._nalm()
        res[0:self.lmax+1] = inp[0:self.lmax+1]
        res[self.lmax+1:] = np.sqrt(0.5)*(inp[self.lmax+1::2] +
                                          1j*inp[self.lmax+2::2])
        res = self.sjob.alm2map(res)
        return res/np.sqrt(np.pi*4)

    def apply(self, x, mode):
        """Transform `x` between `self.pdom` and `self.hdom`.

        Parameters
        ----------
        x : Field
            Input field. If its domain equals `self.pdom` the transform goes
            to `self.hdom`, otherwise to `self.pdom`.
        mode :
            Not used by this method; the direction is taken from
            `x.domain` alone.

        Returns
        -------
        Field
            The transformed field, distributed like `x`.
        """
        axes = x.domain.axes[self.space]
        # NOTE(review): only the first axis of `space` is checked for
        # distribution; presumably the spherical spaces are one-dimensional.
        axis = axes[0]
        tval = x.val
        # Each slice must be fully local, so move the distribution off the
        # transformed axis if necessary.
        if dobj.distaxis(tval) == axis:
            tval = dobj.redistribute(tval, nodist=(axis,))
        distaxis = dobj.distaxis(tval)

        p2h = x.domain == self.pdom
        tdom = self.hdom if p2h else self.pdom
        func = self._slice_p2h if p2h else self._slice_h2p
        idat = dobj.local_data(tval)
        odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
                        dtype=x.dtype)
        for slc in utilities.get_slice_list(idat.shape, axes):
            odat[slc] = func(idat[slc])
        odat = dobj.from_local_data(tdom.shape, odat, distaxis)
        # Restore the caller's original distribution.
        if distaxis != dobj.distaxis(x.val):
            odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
        return Field(tdom, odat)