# test_fft_operator.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
14
15
16
17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
Martin Reinecke's avatar
Martin Reinecke committed
18
19
20

import unittest
import numpy as np
Theo Steininger's avatar
Theo Steininger committed
21
22
from numpy.testing import assert_equal,\
    assert_allclose
23
from nifty.config import dependency_injector as di
Martin Reinecke's avatar
Martin Reinecke committed
24
from nifty import Field,\
Theo Steininger's avatar
Theo Steininger committed
25
26
    RGSpace,\
    LMSpace,\
27
28
    HPSpace,\
    GLSpace,\
29
    FFTOperator
Martin Reinecke's avatar
Martin Reinecke committed
30
31
from itertools import product
from test.common import expand
32
33
from nose.plugins.skip import SkipTest

Theo Steininger's avatar
Theo Steininger committed
34

Martin Reinecke's avatar
Martin Reinecke committed
35
def _harmonic_type(itp):
    """Return the dtype to use on the harmonic (target) side of an FFT.

    Real floating types are promoted to the complex type of the same
    precision (float64 -> complex128, float32 -> complex64); any other
    type, including the complex types themselves, is returned unchanged.
    """
    if itp == np.float64:
        return np.complex128
    if itp == np.float32:
        return np.complex64
    return itp

Theo Steininger's avatar
Theo Steininger committed
43

Martin Reinecke's avatar
Martin Reinecke committed
44
def _get_rtol(tp):
    """Return the comparison tolerance appropriate for dtype *tp*.

    Double-precision types (float64, complex128) get a tight 1e-10; every
    other type (e.g. float32, complex64) gets the looser 1e-5. The tests
    use the value for both ``rtol`` and ``atol``.
    """
    if tp == np.float64 or tp == np.complex128:
        return 1e-10
    return 1e-5
Martin Reinecke's avatar
Martin Reinecke committed
49

Theo Steininger's avatar
Theo Steininger committed
50

51
class FFTOperatorTests(unittest.TestCase):
    """Tests for RGSpace distance arrays and FFTOperator transforms.

    The transform tests check that ``adjoint_times`` undoes / is the adjoint
    of ``times`` for regular grids (RGSpace <-> harmonic RGSpace) and for
    spherical harmonic transforms (LMSpace <-> GLSpace / HPSpace).
    """

    @staticmethod
    def _require(dependency):
        # Skip the running test when an optional backend has not been
        # registered with the dependency injector ``di``.
        if dependency not in di:
            raise SkipTest("optional dependency '%s' is not available"
                           % dependency)

    @staticmethod
    def _check_adjoint_norm(a, b, tp):
        # For x drawn on domain ``a``, compare sqrt(<Ax, Ax>) with
        # sqrt(<x, A^adj A x>); the two agree when ``adjoint_times`` is the
        # adjoint of ``times``.
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.times(inp)
        v1 = np.sqrt(out.dot(out))
        v2 = np.sqrt(inp.dot(fft.adjoint_times(out)))
        tol = _get_rtol(tp)
        assert_allclose(v1, v2, rtol=tol, atol=tol)

    @expand(product([10, 11], [False, True], [0.1, 1, 3.7]))
    def test_RG_distance_1D(self, dim1, zc1, d):
        # The distance array must vanish at index 0, or at the central pixel
        # dim1 // 2 when the space is zero-centered.
        foo = RGSpace([dim1], zerocenter=zc1, distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2)], 0.)

    @expand(product([10, 11], [9, 28], [False, True], [False, True],
                    [0.1, 1, 3.7]))
    def test_RG_distance_2D(self, dim1, dim2, zc1, zc2, d):
        # 2D analogue of test_RG_distance_1D with per-axis zero-centering.
        foo = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=d)
        res = foo.get_distance_array('not')
        assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)

    @expand(product(["scalar", "mpi"], [10, 11], [False, True], [False, True],
                    [0.1, 1, 3.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
        # Forward transform followed by the adjoint must reproduce the input
        # on a 1D regular grid.
        if module == "mpi":
            self._require("fftw_mpi")
        tol = _get_rtol(itp)
        a = RGSpace(dim1, zerocenter=zc1, distances=d)
        # Harmonic partner: same shape, pixel size 1 / (dim1 * d).
        b = RGSpace(dim1, zerocenter=zc2, distances=1./(dim1*d), harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=tol, atol=tol)

    @expand(product(["scalar", "mpi"], [10, 11], [9, 12], [False, True],
                    [False, True], [False, True], [False, True], [0.1, 1, 3.7],
                    [0.4, 1, 2.7],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
        # 2D analogue of test_fft1D with independent zero-centering and
        # pixel sizes per axis.
        if module == "mpi":
            self._require("fftw_mpi")
        tol = _get_rtol(itp)
        a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
        # Harmonic partner: per-axis pixel size 1 / (dim_i * d_i).
        b = RGSpace([dim1, dim2], zerocenter=[zc3, zc4],
                    distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
        fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
                          target_dtype=_harmonic_type(itp), module=module)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=itp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=tol, atol=tol)

    @expand(product([0, 3, 6, 11, 30],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht(self, lm, tp):
        # Round trip LMSpace -> GLSpace -> LMSpace must reproduce the input.
        self._require('pyHealpix')
        tol = _get_rtol(tp)
        a = LMSpace(lmax=lm)
        b = GLSpace(nlat=lm+1)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=tol, atol=tol)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_sht2(self, lm, tp):
        # Round trip LMSpace -> HPSpace -> LMSpace, checked only with a
        # deliberately loose tolerance (rtol=1e-3, atol=1e-1).
        self._require('pyHealpix')
        a = LMSpace(lmax=lm)
        b = HPSpace(nside=lm//2)
        fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
        inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
                                dtype=tp)
        out = fft.adjoint_times(fft.times(inp))
        assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht(self, lm, tp):
        # Adjoint consistency of the LMSpace -> GLSpace transform.
        self._require('pyHealpix')
        self._check_adjoint_norm(LMSpace(lmax=lm), GLSpace(nlat=lm+1), tp)

    @expand(product([128, 256],
                    [np.float64, np.complex128, np.float32, np.complex64]))
    def test_dotsht2(self, lm, tp):
        # Adjoint consistency of the LMSpace -> HPSpace transform.
        self._require('pyHealpix')
        self._check_adjoint_norm(LMSpace(lmax=lm), HPSpace(nside=lm//2), tp)