Commit 693d7194 authored by Theo Steininger's avatar Theo Steininger

Some renaming.

parent b6257738
Pipeline #13332 failed with stage
in 5 minutes and 7 seconds
......@@ -28,26 +28,27 @@ __all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
('pyfftw', 'fftw'),
'pyHealpix',
'plotly'])
dependency_injector.register(('pyfftw', 'fftw_mpi'),
lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw', 'fftw_scalar'))
def _fft_module_checker(z):
if z == 'mpi_fftw':
return 'fftw_mpi' in dependency_injector
if z == 'scalar_fftw':
return 'fftw_scalar' in dependency_injector
if z == 'fftw_mpi':
if 'fftw' in dependency_injector:
if lambda z: hasattr(dependency_injector['fftw'], 'FFTW_MPI'):
return True
else:
return False
if z == 'fftw':
return 'fftw' in dependency_injector
return True
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['mpi_fftw', 'scalar_fftw', 'scalar_numpy'],
lambda z: _fft_module_checker(z))
['fftw_mpi', 'fftw', 'numpy'],
_fft_module_checker)
def dtype_validator(dtype):
......
......@@ -25,8 +25,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable
fftw_mpi = gdi.get('fftw_mpi')
fftw_scalar = gdi.get('fftw_scalar')
fftw = gdi.get('fftw')
class Transform(Loggable, object):
......@@ -208,14 +207,14 @@ class MPIFFT(Transform):
def __init__(self, domain, codomain):
if fftw_mpi is None:
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
super(MPIFFT, self).__init__(domain, codomain)
# Enable caching
fftw_mpi.interfaces.cache.enable()
fftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
......@@ -469,7 +468,7 @@ class MPIFFT(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if fftw_mpi is None:
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
......@@ -515,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context,
**kwargs)
if codomain.harmonic:
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.fftn
self._fftw_interface = fftw.interfaces.numpy_fft.fftn
else:
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.ifftn
self._fftw_interface = fftw.interfaces.numpy_fft.ifftn
@property
def fftw_interface(self):
......@@ -534,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q,
fftw_context,
**kwargs)
self._plan = fftw_mpi.create_mpi_plan(
self._plan = fftw.create_mpi_plan(
input_shape=transform_shape,
input_dtype='complex128',
output_dtype='complex128',
......@@ -548,22 +547,22 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan
class ScalarFFT(Transform):
class SerialFFT(Transform):
"""
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain, fftw):
super(ScalarFFT, self).__init__(domain, codomain)
def __init__(self, domain, codomain, use_fftw):
super(SerialFFT, self).__init__(domain, codomain)
if fftw and (fftw_scalar is None):
if use_fftw and (fftw is None):
raise ImportError(
"The scalar FFTW module is needed but not available.")
"The serial FFTW module is needed but not available.")
self._fftw = fftw
self._use_fftw = use_fftw
# Enable caching
if self._fftw:
fftw_scalar.interfaces.cache.enable()
if self._use_fftw:
fftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
"""
......@@ -640,12 +639,12 @@ class ScalarFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self._fftw:
if self._use_fftw:
if self.codomain.harmonic:
result_val = fftw_scalar.interfaces.numpy_fft.fftn(
result_val = fftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
else:
result_val = fftw_scalar.interfaces.numpy_fft.ifftn(
result_val = fftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
else:
if self.codomain.harmonic:
......
......@@ -18,7 +18,7 @@
import numpy as np
from transformation import Transformation
from rg_transforms import MPIFFT, ScalarFFT
from rg_transforms import MPIFFT, SerialFFT
from nifty import RGSpace, nifty_configuration
......@@ -32,12 +32,14 @@ class RGRGTransformation(Transformation):
if module is None:
module = nifty_configuration['fft_module']
if module == 'mpi_fftw':
if module == 'fftw_mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar_fftw':
self._transform = ScalarFFT(self.domain, self.codomain, True)
elif module == 'scalar_numpy':
self._transform = ScalarFFT(self.domain, self.codomain, False)
elif module == 'fftw':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=True)
elif module == 'numpy':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=False)
else:
raise ValueError('Unsupported FFT module:' + module)
......
......@@ -62,14 +62,15 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
......@@ -81,15 +82,16 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["scalar_numpy", "scalar_fftw", "mpi_fftw"],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "mpi_fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if module == "scalar_fftw" and "fftw_scalar" not in gdi:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......
......@@ -29,7 +29,7 @@ from types import NoneType
from test.common import expand
from itertools import product, chain
# needed to check wether fftw is available
from d2o.config import dependency_injector as gdi
from nifty import dependency_injector as gdi
from nose.plugins.skip import SkipTest
HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
......@@ -134,24 +134,27 @@ class PowerSpaceInterfaceTest(unittest.TestCase):
class PowerSpaceConsistencyCheck(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS)
def test_pipundexInversion(self, harmonic_partner, distribution_strategy,
binbounds, nbin, logarithmic):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi:
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin,
binbounds=binbounds)
assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim),
err_msg='pundex is not right-inverse of pindex!')
# @expand(CONSISTENCY_CONFIGS)
# def test_pipundexInversion(self, harmonic_partner, distribution_strategy,
# binbounds, nbin, logarithmic):
# if distribution_strategy == "fftw":
# if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
# raise SkipTest
# p = PowerSpace(harmonic_partner=harmonic_partner,
# distribution_strategy=distribution_strategy,
# logarithmic=logarithmic, nbin=nbin,
# binbounds=binbounds)
# assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim),
# err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner,
distribution_strategy, binbounds, nbin,
logarithmic):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi:
raise SkipTest
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
print (gdi.get('fftw'), "blub \n\n\n")
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin,
......@@ -164,7 +167,9 @@ class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSTRUCTOR_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy,
logarithmic, nbin, binbounds, expected):
if distribution_strategy == "fftw" and "fftw_mpi" not in gdi:
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
raise SkipTest
if 'error' in expected:
with assert_raises(expected['error']):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment