Commit b6257738 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow explicit selection of scalar FFT module

parent dcebff53
Pipeline #13265 passed with stage
in 5 minutes and 19 seconds
...@@ -35,12 +35,19 @@ dependency_injector.register(('pyfftw', 'fftw_mpi'), ...@@ -35,12 +35,19 @@ dependency_injector.register(('pyfftw', 'fftw_mpi'),
lambda z: hasattr(z, 'FFTW_MPI')) lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw', 'fftw_scalar')) dependency_injector.register(('pyfftw', 'fftw_scalar'))
def _fft_module_checker(z):
if z == 'mpi_fftw':
return 'fftw_mpi' in dependency_injector
if z == 'scalar_fftw':
return 'fftw_scalar' in dependency_injector
return True
# Initialize the variables # Initialize the variables
variable_fft_module = keepers.Variable( variable_fft_module = keepers.Variable(
'fft_module', 'fft_module',
['mpi', 'scalar'], ['mpi_fftw', 'scalar_fftw', 'scalar_numpy'],
lambda z: (('fftw_mpi' in dependency_injector) lambda z: _fft_module_checker(z))
if z == 'mpi' else True))
def dtype_validator(dtype): def dtype_validator(dtype):
......
...@@ -553,6 +553,17 @@ class ScalarFFT(Transform): ...@@ -553,6 +553,17 @@ class ScalarFFT(Transform):
The numpy fft pendant of a fft object. The numpy fft pendant of a fft object.
""" """
def __init__(self, domain, codomain, fftw):
super(ScalarFFT, self).__init__(domain, codomain)
if fftw and (fftw_scalar is None):
raise ImportError(
"The scalar FFTW module is needed but not available.")
self._fftw = fftw
# Enable caching
if self._fftw:
fftw_scalar.interfaces.cache.enable()
def transform(self, val, axes, **kwargs): def transform(self, val, axes, **kwargs):
""" """
...@@ -575,9 +586,6 @@ class ScalarFFT(Transform): ...@@ -575,9 +586,6 @@ class ScalarFFT(Transform):
result : np.ndarray or distributed_data_object result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# Enable caching
if fftw_scalar is not None:
fftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape # Check if the axes provided are valid given the shape
if axes is not None and \ if axes is not None and \
...@@ -632,7 +640,7 @@ class ScalarFFT(Transform): ...@@ -632,7 +640,7 @@ class ScalarFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes) local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation # perform the transformation
if fftw_scalar is not None: if self._fftw:
if self.codomain.harmonic: if self.codomain.harmonic:
result_val = fftw_scalar.interfaces.numpy_fft.fftn( result_val = fftw_scalar.interfaces.numpy_fft.fftn(
local_val, axes=axes) local_val, axes=axes)
......
...@@ -30,20 +30,16 @@ class RGRGTransformation(Transformation): ...@@ -30,20 +30,16 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module) super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None: if module is None:
if nifty_configuration['fft_module'] == 'mpi': module = nifty_configuration['fft_module']
self._transform = MPIFFT(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'scalar': if module == 'mpi_fftw':
self._transform = ScalarFFT(self.domain, self.codomain) self._transform = MPIFFT(self.domain, self.codomain)
else: elif module == 'scalar_fftw':
raise ValueError('Unsupported default FFT module:' + self._transform = ScalarFFT(self.domain, self.codomain, True)
nifty_configuration['fft_module']) elif module == 'scalar_numpy':
self._transform = ScalarFFT(self.domain, self.codomain, False)
else: else:
if module == 'mpi': raise ValueError('Unsupported FFT module:' + module)
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar':
self._transform = ScalarFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import numpy as np import numpy as np
from numpy.testing import assert_equal,\ from numpy.testing import assert_equal,\
assert_allclose assert_allclose
from nifty.config import dependency_injector as di from nifty.config import dependency_injector as gdi
from nifty import Field,\ from nifty import Field,\
RGSpace,\ RGSpace,\
LMSpace,\ LMSpace,\
...@@ -62,11 +62,14 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -62,11 +62,14 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not') res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.) assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["scalar", "mpi"], [10, 11], [False, True], [False, True], @expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
[10, 11], [False, True], [False, True],
[0.1, 1, 3.7], [0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp): def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "mpi" and "fftw_mpi" not in di: if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d) a = RGSpace(dim1, zerocenter=zc1, distances=d)
...@@ -78,12 +81,15 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -78,12 +81,15 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol) assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["scalar", "mpi"], [10, 11], [9, 12], [False, True], @expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
[10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7], [False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7], [0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp): def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "mpi" and "fftw_mpi" not in di: if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2]) a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
...@@ -99,7 +105,7 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -99,7 +105,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([0, 3, 6, 11, 30], @expand(product([0, 3, 6, 11, 30],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_sht(self, lm, tp): def test_sht(self, lm, tp):
if 'pyHealpix' not in di: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(tp) tol = _get_rtol(tp)
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
...@@ -113,7 +119,7 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -113,7 +119,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_sht2(self, lm, tp): def test_sht2(self, lm, tp):
if 'pyHealpix' not in di: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2) b = HPSpace(nside=lm//2)
...@@ -126,7 +132,7 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -126,7 +132,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp): def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(tp) tol = _get_rtol(tp)
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
...@@ -142,7 +148,7 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -142,7 +148,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp): def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(tp) tol = _get_rtol(tp)
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment