Commit b6257738 authored by Martin Reinecke's avatar Martin Reinecke

allow explicit selection of scalar FFT module

parent dcebff53
Pipeline #13265 passed with stage
in 5 minutes and 19 seconds
......@@ -35,12 +35,19 @@ dependency_injector.register(('pyfftw', 'fftw_mpi'),
lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw', 'fftw_scalar'))
def _fft_module_checker(z):
if z == 'mpi_fftw':
return 'fftw_mpi' in dependency_injector
if z == 'scalar_fftw':
return 'fftw_scalar' in dependency_injector
return True
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['mpi', 'scalar'],
lambda z: (('fftw_mpi' in dependency_injector)
if z == 'mpi' else True))
['mpi_fftw', 'scalar_fftw', 'scalar_numpy'],
lambda z: _fft_module_checker(z))
def dtype_validator(dtype):
......
......@@ -553,6 +553,17 @@ class ScalarFFT(Transform):
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain, fftw):
super(ScalarFFT, self).__init__(domain, codomain)
if fftw and (fftw_scalar is None):
raise ImportError(
"The scalar FFTW module is needed but not available.")
self._fftw = fftw
# Enable caching
if self._fftw:
fftw_scalar.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
"""
......@@ -575,9 +586,6 @@ class ScalarFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Enable caching
if fftw_scalar is not None:
fftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape
if axes is not None and \
......@@ -632,7 +640,7 @@ class ScalarFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if fftw_scalar is not None:
if self._fftw:
if self.codomain.harmonic:
result_val = fftw_scalar.interfaces.numpy_fft.fftn(
local_val, axes=axes)
......
......@@ -30,20 +30,16 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'scalar':
self._transform = ScalarFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module'])
module = nifty_configuration['fft_module']
if module == 'mpi_fftw':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar_fftw':
self._transform = ScalarFFT(self.domain, self.codomain, True)
elif module == 'scalar_numpy':
self._transform = ScalarFFT(self.domain, self.codomain, False)
else:
if module == 'mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar':
self._transform = ScalarFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported FFT module:' + module)
raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods---
......
......@@ -20,7 +20,7 @@ import unittest
import numpy as np
from numpy.testing import assert_equal,\
assert_allclose
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
from nifty import Field,\
RGSpace,\
LMSpace,\
......@@ -62,11 +62,14 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["scalar", "mpi"], [10, 11], [False, True], [False, True],
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
[10, 11], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "mpi" and "fftw_mpi" not in di:
if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
......@@ -78,12 +81,15 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["scalar", "mpi"], [10, 11], [9, 12], [False, True],
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
[10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "mpi" and "fftw_mpi" not in di:
if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......@@ -99,7 +105,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([0, 3, 6, 11, 30],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -113,7 +119,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
......@@ -126,7 +132,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -142,7 +148,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment