Commit 693d7194 authored by Theo Steininger's avatar Theo Steininger
Browse files

Some renaming.

parent b6257738
Pipeline #13332 failed with stage
in 5 minutes and 7 seconds
...@@ -28,26 +28,27 @@ __all__ = ['dependency_injector', 'nifty_configuration'] ...@@ -28,26 +28,27 @@ __all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector # Setup the dependency injector
dependency_injector = keepers.DependencyInjector( dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'), [('mpi4py.MPI', 'MPI'),
('pyfftw', 'fftw'),
'pyHealpix', 'pyHealpix',
'plotly']) 'plotly'])
dependency_injector.register(('pyfftw', 'fftw_mpi'),
lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw', 'fftw_scalar'))
def _fft_module_checker(z): def _fft_module_checker(z):
if z == 'mpi_fftw': if z == 'fftw_mpi':
return 'fftw_mpi' in dependency_injector if 'fftw' in dependency_injector:
if z == 'scalar_fftw': if lambda z: hasattr(dependency_injector['fftw'], 'FFTW_MPI'):
return 'fftw_scalar' in dependency_injector return True
else:
return False
if z == 'fftw':
return 'fftw' in dependency_injector
return True return True
# Initialize the variables # Initialize the variables
variable_fft_module = keepers.Variable( variable_fft_module = keepers.Variable(
'fft_module', 'fft_module',
['mpi_fftw', 'scalar_fftw', 'scalar_numpy'], ['fftw_mpi', 'fftw', 'numpy'],
lambda z: _fft_module_checker(z)) _fft_module_checker)
def dtype_validator(dtype): def dtype_validator(dtype):
......
...@@ -25,8 +25,7 @@ import nifty.nifty_utilities as utilities ...@@ -25,8 +25,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable from keepers import Loggable
fftw_mpi = gdi.get('fftw_mpi') fftw = gdi.get('fftw')
fftw_scalar = gdi.get('fftw_scalar')
class Transform(Loggable, object): class Transform(Loggable, object):
...@@ -208,14 +207,14 @@ class MPIFFT(Transform): ...@@ -208,14 +207,14 @@ class MPIFFT(Transform):
def __init__(self, domain, codomain): def __init__(self, domain, codomain):
if fftw_mpi is None: if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError( raise ImportError(
"The MPI FFTW module is needed but not available.") "The MPI FFTW module is needed but not available.")
super(MPIFFT, self).__init__(domain, codomain) super(MPIFFT, self).__init__(domain, codomain)
# Enable caching # Enable caching
fftw_mpi.interfaces.cache.enable() fftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond # The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets. # to a certain set of (field_val, domain, codomain) sets.
...@@ -469,7 +468,7 @@ class MPIFFT(Transform): ...@@ -469,7 +468,7 @@ class MPIFFT(Transform):
class FFTWTransformInfo(object): class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
if fftw_mpi is None: if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError( raise ImportError(
"The MPI FFTW module is needed but not available.") "The MPI FFTW module is needed but not available.")
...@@ -515,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo): ...@@ -515,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context, fftw_context,
**kwargs) **kwargs)
if codomain.harmonic: if codomain.harmonic:
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.fftn self._fftw_interface = fftw.interfaces.numpy_fft.fftn
else: else:
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.ifftn self._fftw_interface = fftw.interfaces.numpy_fft.ifftn
@property @property
def fftw_interface(self): def fftw_interface(self):
...@@ -534,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo): ...@@ -534,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
**kwargs) **kwargs)
self._plan = fftw_mpi.create_mpi_plan( self._plan = fftw.create_mpi_plan(
input_shape=transform_shape, input_shape=transform_shape,
input_dtype='complex128', input_dtype='complex128',
output_dtype='complex128', output_dtype='complex128',
...@@ -548,22 +547,22 @@ class FFTWMPITransfromInfo(FFTWTransformInfo): ...@@ -548,22 +547,22 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan return self._plan
class ScalarFFT(Transform): class SerialFFT(Transform):
""" """
The numpy fft pendant of a fft object. The numpy fft pendant of a fft object.
""" """
def __init__(self, domain, codomain, fftw): def __init__(self, domain, codomain, use_fftw):
super(ScalarFFT, self).__init__(domain, codomain) super(SerialFFT, self).__init__(domain, codomain)
if fftw and (fftw_scalar is None): if use_fftw and (fftw is None):
raise ImportError( raise ImportError(
"The scalar FFTW module is needed but not available.") "The serial FFTW module is needed but not available.")
self._fftw = fftw self._use_fftw = use_fftw
# Enable caching # Enable caching
if self._fftw: if self._use_fftw:
fftw_scalar.interfaces.cache.enable() fftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs): def transform(self, val, axes, **kwargs):
""" """
...@@ -640,12 +639,12 @@ class ScalarFFT(Transform): ...@@ -640,12 +639,12 @@ class ScalarFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes) local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation # perform the transformation
if self._fftw: if self._use_fftw:
if self.codomain.harmonic: if self.codomain.harmonic:
result_val = fftw_scalar.interfaces.numpy_fft.fftn( result_val = fftw.interfaces.numpy_fft.fftn(
local_val, axes=axes) local_val, axes=axes)
else: else:
result_val = fftw_scalar.interfaces.numpy_fft.ifftn( result_val = fftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes) local_val, axes=axes)
else: else:
if self.codomain.harmonic: if self.codomain.harmonic:
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import numpy as np import numpy as np
from transformation import Transformation from transformation import Transformation
from rg_transforms import MPIFFT, ScalarFFT from rg_transforms import MPIFFT, SerialFFT
from nifty import RGSpace, nifty_configuration from nifty import RGSpace, nifty_configuration
...@@ -32,12 +32,14 @@ class RGRGTransformation(Transformation): ...@@ -32,12 +32,14 @@ class RGRGTransformation(Transformation):
if module is None: if module is None:
module = nifty_configuration['fft_module'] module = nifty_configuration['fft_module']
if module == 'mpi_fftw': if module == 'fftw_mpi':
self._transform = MPIFFT(self.domain, self.codomain) self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar_fftw': elif module == 'fftw':
self._transform = ScalarFFT(self.domain, self.codomain, True) self._transform = SerialFFT(self.domain, self.codomain,
elif module == 'scalar_numpy': use_fftw=True)
self._transform = ScalarFFT(self.domain, self.codomain, False) elif module == 'numpy':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=False)
else: else:
raise ValueError('Unsupported FFT module:' + module) raise ValueError('Unsupported FFT module:' + module)
......
...@@ -62,14 +62,15 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -62,14 +62,15 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not') res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.) assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"], @expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [False, True], [False, True], [10, 11], [False, True], [False, True],
[0.1, 1, 3.7], [0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp): def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "mpi_fftw" and "fftw_mpi" not in gdi: if module == "fftw_mpi":
raise SkipTest if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
if module == "scalar_fftw" and "fftw_scalar" not in gdi: raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d) a = RGSpace(dim1, zerocenter=zc1, distances=d)
...@@ -81,15 +82,16 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -81,15 +82,16 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol) assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"], @expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [9, 12], [False, True], [10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7], [False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7], [0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp): def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "mpi_fftw" and "fftw_mpi" not in gdi: if module == "fftw_mpi":
raise SkipTest if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
if module == "scalar_fftw" and "fftw_scalar" not in gdi: raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2]) a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......
...@@ -29,7 +29,7 @@ from types import NoneType ...@@ -29,7 +29,7 @@ from types import NoneType
from test.common import expand from test.common import expand
from itertools import product, chain from itertools import product, chain
# needed to check wether fftw is available # needed to check wether fftw is available
from d2o.config import dependency_injector as gdi from nifty import dependency_injector as gdi
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
HARMONIC_SPACES = [RGSpace((8,), harmonic=True), HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
...@@ -134,24 +134,27 @@ class PowerSpaceInterfaceTest(unittest.TestCase): ...@@ -134,24 +134,27 @@ class PowerSpaceInterfaceTest(unittest.TestCase):
class PowerSpaceConsistencyCheck(unittest.TestCase): class PowerSpaceConsistencyCheck(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS) # @expand(CONSISTENCY_CONFIGS)
def test_pipundexInversion(self, harmonic_partner, distribution_strategy, # def test_pipundexInversion(self, harmonic_partner, distribution_strategy,
binbounds, nbin, logarithmic): # binbounds, nbin, logarithmic):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi: # if distribution_strategy == "fftw":
raise SkipTest # if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
p = PowerSpace(harmonic_partner=harmonic_partner, # raise SkipTest
distribution_strategy=distribution_strategy, # p = PowerSpace(harmonic_partner=harmonic_partner,
logarithmic=logarithmic, nbin=nbin, # distribution_strategy=distribution_strategy,
binbounds=binbounds) # logarithmic=logarithmic, nbin=nbin,
assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim), # binbounds=binbounds)
err_msg='pundex is not right-inverse of pindex!') # assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim),
# err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS) @expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner, def test_rhopindexConsistency(self, harmonic_partner,
distribution_strategy, binbounds, nbin, distribution_strategy, binbounds, nbin,
logarithmic): logarithmic):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi: if distribution_strategy == "fftw":
raise SkipTest if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
print (gdi.get('fftw'), "blub \n\n\n")
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner, p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin, logarithmic=logarithmic, nbin=nbin,
...@@ -164,7 +167,9 @@ class PowerSpaceFunctionalityTest(unittest.TestCase): ...@@ -164,7 +167,9 @@ class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSTRUCTOR_CONFIGS) @expand(CONSTRUCTOR_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy, def test_constructor(self, harmonic_partner, distribution_strategy,
logarithmic, nbin, binbounds, expected): logarithmic, nbin, binbounds, expected):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi: if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
raise SkipTest raise SkipTest
if 'error' in expected: if 'error' in expected:
with assert_raises(expected['error']): with assert_raises(expected['error']):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment