# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
import numpy as np
Theo Steininger's avatar
Theo Steininger committed
21
22
from numpy.testing import assert_equal,\
    assert_allclose
23
from nifty.config import dependency_injector as gdi
Martin Reinecke's avatar
Martin Reinecke committed
24
from nifty import Field,\
Theo Steininger's avatar
Theo Steininger committed
25
26
    RGSpace,\
    LMSpace,\
27
28
    HPSpace,\
    GLSpace,\
29
    FFTOperator
Martin Reinecke's avatar
Martin Reinecke committed
30
31
from itertools import product
from test.common import expand
32
33
from nose.plugins.skip import SkipTest

Theo Steininger's avatar
Theo Steininger committed
34

Martin Reinecke's avatar
Martin Reinecke committed
35
def _harmonic_type(itp):
Theo Steininger's avatar
Theo Steininger committed
36
37
38
39
40
    otp = itp
    if otp == np.float64:
        otp = np.complex128
    elif otp == np.float32:
        otp = np.complex64
Martin Reinecke's avatar
Martin Reinecke committed
41
42
    return otp

Theo Steininger's avatar
Theo Steininger committed
43

Martin Reinecke's avatar
Martin Reinecke committed
44
def _get_rtol(tp):
Theo Steininger's avatar
Theo Steininger committed
45
    if (tp == np.float64) or (tp == np.complex128):
Martin Reinecke's avatar
Martin Reinecke committed
46
47
48
        return 1e-10
    else:
        return 1e-5
Martin Reinecke's avatar
Martin Reinecke committed
49

Theo Steininger's avatar
Theo Steininger committed
50

51
class FFTOperatorTests(unittest.TestCase):
    """Tests for FFTOperator and the RGSpace distance arrays.

    The transform tests check the round trip ``adjoint_times(times(x))``
    against ``x``. The dot tests compare ``sqrt(<A x, A x>)`` with
    ``sqrt(<x, A^H A x>)``. If an optional backend or dependency is not
    registered with the dependency injector, the test is skipped.
    """

    @staticmethod
    def _skip_unless_fft_module(module):
        """Raise SkipTest unless the FFT backend *module* can be used.

        "fftw_mpi" requires the injected 'fftw' dependency to expose
        ``FFTW_MPI``. "fftw" requires 'fftw' to be registered with the
        dependency injector. Other module names are not checked here.
        """
        if module == "fftw_mpi" and not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
            raise SkipTest("FFTW_MPI is not available")
        if module == "fftw" and "fftw" not in gdi:
            raise SkipTest("fftw is not available")

    @staticmethod
    def _skip_unless_pyhealpix():
        """Raise SkipTest unless pyHealpix is registered with the injector."""
        if 'pyHealpix' not in gdi:
            raise SkipTest("pyHealpix is not available")

    def _assert_sht_adjoint(self, domain, target, tp):
        """Check that the SHT from *domain* to *target* matches its adjoint.

        Draws a standard-normal field x on *domain* and checks that
        sqrt(<A x, A x>) agrees with sqrt(<x, A^H A x>) to within the
        dtype-dependent tolerance from ``_get_rtol``.
        """
        tol = _get_rtol(tp)
        fft = FFTOperator(domain=domain, target=target, domain_dtype=tp,
                          target_dtype=tp)
        inp = Field.from_random(domain=domain, random_type='normal', std=1,
                                mean=0, dtype=tp)
        out = fft.times(inp)
        v1 = np.sqrt(out.vdot(out))
        v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
        assert_allclose(v1, v2, rtol=tol, atol=tol)

    @expand(product([10, 11], [False, True], [0.1, 1, 3.7]))
    def test_RG_distance_1D(self, dim1, zc1, d):
        """The 1D distance array is zero at index 0, or dim//2 if centred."""
        foo = RGSpace([dim1], zerocenter=zc1, distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2)], 0.)

    @expand(product([10, 11], [9, 28], [False, True], [False, True],
                    [0.1, 1, 3.7]))
    def test_RG_distance_2D(self, dim1, dim2, zc1, zc2, d):
        """The 2D distance array is zero at the (possibly centred) origin."""
        foo = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)

    @expand(product(["numpy", "fftw", "fftw_mpi"],
                    [16, ], [False, True], [False, True],
                    [0.1, 1, 3.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
        """Round trip a 1D RGSpace field through the regular-grid FFT."""
        self._skip_unless_fft_module(module)
        tol = _get_rtol(itp)
        a = RGSpace(dim1, zerocenter=zc1, distances=d)
        # The harmonic grid uses the reciprocal spacing 1 / (N * d).
        b = RGSpace(dim1, zerocenter=zc2, distances=1./(dim1*d),
                    harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        np.random.seed(16)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val.get_full_data(),
                        out.val.get_full_data(),
                        rtol=tol, atol=tol)

    @expand(product(["numpy", "fftw", "fftw_mpi"],
                    [12, 15], [9, 12], [False, True],
                    [False, True], [False, True], [False, True], [0.1, 1, 3.7],
                    [0.4, 1, 2.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
        """Round trip a 2D RGSpace field through the regular-grid FFT."""
        self._skip_unless_fft_module(module)
        tol = _get_rtol(itp)
        a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
        # Per-axis reciprocal spacing 1 / (N_i * d_i) on the harmonic grid.
        b = RGSpace([dim1, dim2], zerocenter=[zc3, zc4],
                    distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=tol, atol=tol)

    @expand(product([0, 3, 6, 11, 30],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht(self, lm, tp):
        """Round trip an LMSpace field through the SHT to a GLSpace."""
        self._skip_unless_pyhealpix()
        tol = _get_rtol(tp)
        a = LMSpace(lmax=lm)
        b = GLSpace(nlat=lm+1)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=tol, atol=tol)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht2(self, lm, tp):
        """Round trip an LMSpace field through the SHT to an HPSpace."""
        self._skip_unless_pyhealpix()
        a = LMSpace(lmax=lm)
        b = HPSpace(nside=lm//2)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        # NOTE(review): much looser than the GLSpace round trip; presumably
        # because the HEALPix round trip is only approximate -- TODO confirm.
        assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht(self, lm, tp):
        """The LMSpace -> GLSpace SHT is consistent with its adjoint."""
        self._skip_unless_pyhealpix()
        self._assert_sht_adjoint(LMSpace(lmax=lm), GLSpace(nlat=lm+1), tp)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht2(self, lm, tp):
        """The LMSpace -> HPSpace SHT is consistent with its adjoint."""
        self._skip_unless_pyhealpix()
        self._assert_sht_adjoint(LMSpace(lmax=lm), HPSpace(nside=lm//2), tp)