test_fft_operator.py 6.82 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
14
15
16
17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
Martin Reinecke's avatar
Martin Reinecke committed
18

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import division
Martin Reinecke's avatar
Martin Reinecke committed
20
21
import unittest
import numpy as np
Theo Steininger's avatar
Theo Steininger committed
22
23
from numpy.testing import assert_equal,\
    assert_allclose
24
from nifty.config import dependency_injector as gdi
Martin Reinecke's avatar
Martin Reinecke committed
25
from nifty import Field,\
Theo Steininger's avatar
Theo Steininger committed
26
27
    RGSpace,\
    LMSpace,\
28
29
    HPSpace,\
    GLSpace,\
30
    FFTOperator
Martin Reinecke's avatar
Martin Reinecke committed
31
32
from itertools import product
from test.common import expand
33
34
from nose.plugins.skip import SkipTest

Theo Steininger's avatar
Theo Steininger committed
35

Martin Reinecke's avatar
Martin Reinecke committed
36
def _harmonic_type(itp):
Theo Steininger's avatar
Theo Steininger committed
37
38
39
40
41
    otp = itp
    if otp == np.float64:
        otp = np.complex128
    elif otp == np.float32:
        otp = np.complex64
Martin Reinecke's avatar
Martin Reinecke committed
42
43
    return otp

Theo Steininger's avatar
Theo Steininger committed
44

Martin Reinecke's avatar
Martin Reinecke committed
45
def _get_rtol(tp):
Theo Steininger's avatar
Theo Steininger committed
46
    if (tp == np.float64) or (tp == np.complex128):
Martin Reinecke's avatar
Martin Reinecke committed
47
48
49
        return 1e-10
    else:
        return 1e-5
Martin Reinecke's avatar
Martin Reinecke committed
50

Theo Steininger's avatar
Theo Steininger committed
51

52
class FFTOperatorTests(unittest.TestCase):
    """Tests for RGSpace distance arrays and FFTOperator transforms.

    The transform tests check round trips (adjoint_times after times) and
    dot-product (adjointness) identities for regular-grid FFTs and for
    spherical harmonic transforms between LMSpace and GL/HP spaces.
    """

    @staticmethod
    def _skip_if_unavailable(module):
        """Raise SkipTest if the requested FFT backend is not installed.

        "fftw_mpi" requires the injected fftw dependency to expose an
        FFTW_MPI attribute; "fftw" requires the fftw dependency itself.
        No availability check is performed for "numpy".
        """
        if module == "fftw_mpi":
            # gdi.get may return a falsy placeholder if fftw is absent;
            # hasattr then fails and the test is skipped.
            if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
                raise SkipTest
        if module == "fftw" and "fftw" not in gdi:
            raise SkipTest

    @expand(product([10, 11], [False, True], [0.1, 1, 3.7]))
    def test_RG_distance_1D(self, dim1, zc1, d):
        """The distance array is zero at the grid origin.

        The origin is index 0, or dim1 // 2 when the axis is zero-centered.
        """
        foo = RGSpace([dim1], zerocenter=zc1, distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2)], 0.)

    @expand(product([10, 11], [9, 28], [False, True], [False, True],
                    [0.1, 1, 3.7]))
    def test_RG_distance_2D(self, dim1, dim2, zc1, zc2, d):
        """2D version of test_RG_distance_1D, with independent zero-centering."""
        foo = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)

    @expand(product(["numpy", "fftw", "fftw_mpi"],
                    [12, ], [False, True], [False, True],
                    [0.1, 1, 3.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
        """A 1D FFT followed by its adjoint reproduces the input field."""
        self._skip_if_unavailable(module)
        tol = _get_rtol(itp)
        a = RGSpace(dim1, zerocenter=zc1, distances=d)
        # harmonic partner grid: reciprocal pixel size 1 / (N * d)
        b = RGSpace(dim1, zerocenter=zc2, distances=1./(dim1*d), harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val.get_full_data(),
                        out.val.get_full_data(),
                        rtol=tol, atol=tol)

    @expand(product(["numpy", "fftw", "fftw_mpi"],
                    [12, 15], [9, 12], [False, True],
                    [False, True], [False, True], [False, True], [0.1, 1, 3.7],
                    [0.4, 1, 2.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
        """A 2D FFT followed by its adjoint reproduces the input field."""
        self._skip_if_unavailable(module)
        tol = _get_rtol(itp)
        a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
        b = RGSpace([dim1, dim2], zerocenter=[zc3, zc4],
                    distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        # seed so the tight single-precision tolerance is checked on
        # reproducible data
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val.get_full_data(),
                        out.val.get_full_data(),
                        rtol=tol, atol=tol)

    @expand(product([0, 3, 6, 11, 30],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht(self, lm, tp):
        """LMSpace -> GLSpace(nlat=lmax+1) transform round-trips."""
        if 'pyHealpix' not in gdi:
            raise SkipTest
        tol = _get_rtol(tp)
        a = LMSpace(lmax=lm)
        b = GLSpace(nlat=lm+1)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val.get_full_data(),
                        out.val.get_full_data(),
                        rtol=tol, atol=tol)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht2(self, lm, tp):
        """LMSpace -> HPSpace(nside=lmax//2) transform round-trips.

        Uses fixed, looser tolerances (rtol=1e-3, atol=1e-1) instead of
        _get_rtol.
        """
        if 'pyHealpix' not in gdi:
            raise SkipTest
        a = LMSpace(lmax=lm)
        b = HPSpace(nside=lm//2)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val.get_full_data(),
                        out.val.get_full_data(),
                        rtol=1e-3, atol=1e-1)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht(self, lm, tp):
        """Adjointness check for LMSpace -> GLSpace.

        With y = A x, compares sqrt(<y, y>) with sqrt(<x, A^dagger y>).
        """
        if 'pyHealpix' not in gdi:
            raise SkipTest
        tol = _get_rtol(tp)
        a = LMSpace(lmax=lm)
        b = GLSpace(nlat=lm+1)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.times(inp)
        v1 = np.sqrt(out.vdot(out))
        v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
        assert_allclose(v1, v2, rtol=tol, atol=tol)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht2(self, lm, tp):
        """Adjointness check for LMSpace -> HPSpace (see test_dotsht)."""
        if 'pyHealpix' not in gdi:
            raise SkipTest
        tol = _get_rtol(tp)
        a = LMSpace(lmax=lm)
        b = HPSpace(nside=lm//2)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.times(inp)
        v1 = np.sqrt(out.vdot(out))
        v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
        assert_allclose(v1, v2, rtol=tol, atol=tol)