Commit b0d67512 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

renaming, WIP

parent 88d6e60d
Pipeline #13115 passed with stage
in 5 minutes and 12 seconds
...@@ -31,15 +31,15 @@ dependency_injector = keepers.DependencyInjector( ...@@ -31,15 +31,15 @@ dependency_injector = keepers.DependencyInjector(
'pyHealpix', 'pyHealpix',
'plotly']) 'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI')) dependency_injector.register(('pyfftw','fftw_mpi'), lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw','pyfftw_scalar')) dependency_injector.register(('pyfftw','fftw_scalar'))
# Initialize the variables # Initialize the variables
variable_fft_module = keepers.Variable( variable_fft_module = keepers.Variable(
'fft_module', 'fft_module',
['fftw', 'numpy'], ['mpi', 'scalar'],
lambda z: (('pyfftw' in dependency_injector) lambda z: (('fftw_mpi' in dependency_injector)
if z == 'fftw' else True)) if z == 'mpi' else True))
def dtype_validator(dtype): def dtype_validator(dtype):
......
...@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator): ...@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator):
but for full control, the user should explicitly specify a codomain. but for full control, the user should explicitly specify a codomain.
module: String (optional) module: String (optional)
Software module employed for carrying out the transform operations. Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available, but "fftw" offers higher performance and always available (using pyfftw if available, else numpy.fft), and "mpi"
parallelization. For sphere-related domains, only "pyHealpix" is requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available, available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix". else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional) domain_dtype: data type (optional)
......
...@@ -25,8 +25,8 @@ import nifty.nifty_utilities as utilities ...@@ -25,8 +25,8 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable from keepers import Loggable
pyfftw = gdi.get('pyfftw') fftw_mpi = gdi.get('fftw_mpi')
pyfftw_scalar = gdi.get('pyfftw_scalar') fftw_scalar = gdi.get('fftw_scalar')
class Transform(Loggable, object): class Transform(Loggable, object):
...@@ -201,20 +201,21 @@ class Transform(Loggable, object): ...@@ -201,20 +201,21 @@ class Transform(Loggable, object):
raise NotImplementedError raise NotImplementedError
class FFTW(Transform): class MPIFFT(Transform):
""" """
The pyfftw pendant of a fft object. The MPI-parallel FFTW pendant of a fft object.
""" """
def __init__(self, domain, codomain): def __init__(self, domain, codomain):
if 'pyfftw' not in gdi: if 'fftw_mpi' not in gdi:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError(
"The MPI FFTW module is needed but not available.")
super(FFTW, self).__init__(domain, codomain) super(MPIFFT, self).__init__(domain, codomain)
# Enable caching for pyfftw.interfaces # Enable caching
pyfftw.interfaces.cache.enable() fftw_mpi.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond # The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets. # to a certain set of (field_val, domain, codomain) sets.
...@@ -410,7 +411,7 @@ class FFTW(Transform): ...@@ -410,7 +411,7 @@ class FFTW(Transform):
def transform(self, val, axes, **kwargs): def transform(self, val, axes, **kwargs):
""" """
The pyfftw transform function. The MPI-parallel FFTW transform function.
Parameters Parameters
---------- ----------
...@@ -468,8 +469,9 @@ class FFTW(Transform): ...@@ -468,8 +469,9 @@ class FFTW(Transform):
class FFTWTransformInfo(object): class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
if pyfftw is None: if fftw_mpi is None:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError(
"The MPI FFTW module is needed but not available.")
shape = (local_shape if axes is None else shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes]) [y for x, y in enumerate(local_shape) if x in axes])
...@@ -513,9 +515,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo): ...@@ -513,9 +515,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context, fftw_context,
**kwargs) **kwargs)
if codomain.harmonic: if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn self._fftw_interface = fftw_mpi.interfaces.numpy_fft.fftn
else: else:
self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn self._fftw_interface = fftw_mpi.interfaces.numpy_fft.ifftn
@property @property
def fftw_interface(self): def fftw_interface(self):
...@@ -532,7 +534,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo): ...@@ -532,7 +534,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
**kwargs) **kwargs)
self._plan = pyfftw.create_mpi_plan( self._plan = fftw_mpi.create_mpi_plan(
input_shape=transform_shape, input_shape=transform_shape,
input_dtype='complex128', input_dtype='complex128',
output_dtype='complex128', output_dtype='complex128',
...@@ -546,7 +548,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo): ...@@ -546,7 +548,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan return self._plan
class NUMPYFFT(Transform): class ScalarFFT(Transform):
""" """
The numpy fft pendant of a fft object. The numpy fft pendant of a fft object.
...@@ -554,7 +556,7 @@ class NUMPYFFT(Transform): ...@@ -554,7 +556,7 @@ class NUMPYFFT(Transform):
def transform(self, val, axes, **kwargs): def transform(self, val, axes, **kwargs):
""" """
The pyfftw transform function. The scalar FFT transform function.
Parameters Parameters
---------- ----------
...@@ -573,9 +575,9 @@ class NUMPYFFT(Transform): ...@@ -573,9 +575,9 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# Enable caching for pyfftw_scalar.interfaces # Enable caching
if 'pyfftw_scalar' in gdi: if 'fftw_scalar' in gdi:
pyfftw_scalar.interfaces.cache.enable() fftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape # Check if the axes provided are valid given the shape
if axes is not None and \ if axes is not None and \
...@@ -630,11 +632,11 @@ class NUMPYFFT(Transform): ...@@ -630,11 +632,11 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes) local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation # perform the transformation
if 'pyfftw_scalar' in gdi: if 'fftw_scalar' in gdi:
if self.codomain.harmonic: if self.codomain.harmonic:
result_val = pyfftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes) result_val = fftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes)
else: else:
result_val = pyfftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes) result_val = fftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes)
else: else:
if self.codomain.harmonic: if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes) result_val = np.fft.fftn(local_val, axes=axes)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import numpy as np import numpy as np
from transformation import Transformation from transformation import Transformation
from rg_transforms import FFTW, NUMPYFFT from rg_transforms import MPIFFT, ScalarFFT
from nifty import RGSpace, nifty_configuration from nifty import RGSpace, nifty_configuration
...@@ -30,18 +30,18 @@ class RGRGTransformation(Transformation): ...@@ -30,18 +30,18 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module) super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None: if module is None:
if nifty_configuration['fft_module'] == 'fftw': if nifty_configuration['fft_module'] == 'mpi':
self._transform = FFTW(self.domain, self.codomain) self._transform = MPIFFT(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'numpy': elif nifty_configuration['fft_module'] == 'scalar':
self._transform = NUMPYFFT(self.domain, self.codomain) self._transform = ScalarFFT(self.domain, self.codomain)
else: else:
raise ValueError('Unsupported default FFT module:' + raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module']) nifty_configuration['fft_module'])
else: else:
if module == 'fftw': if module == 'mpi':
self._transform = FFTW(self.domain, self.codomain) self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'numpy': elif module == 'scalar':
self._transform = NUMPYFFT(self.domain, self.codomain) self._transform = ScalarFFT(self.domain, self.codomain)
else: else:
raise ValueError('Unsupported FFT module:' + module) raise ValueError('Unsupported FFT module:' + module)
......
...@@ -62,11 +62,11 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -62,11 +62,11 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not') res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.) assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw"], [10, 11], [False, True], [False, True], @expand(product(["scalar","mpi"], [10, 11], [False, True], [False, True],
[0.1, 1, 3.7], [0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp): def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "fftw" and "pyfftw" not in di: if module == "mpi" and "fftw_mpi" not in di:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d) a = RGSpace(dim1, zerocenter=zc1, distances=d)
...@@ -78,12 +78,12 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -78,12 +78,12 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol) assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"], [10, 11], [9, 12], [False, True], @expand(product(["scalar", "mpi"], [10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7], [False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7], [0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp): def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "fftw" and "pyfftw" not in di: if module == "mpi" and "fftw_mpi" not in di:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2]) a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......
...@@ -32,22 +32,22 @@ from itertools import product, chain ...@@ -32,22 +32,22 @@ from itertools import product, chain
from d2o.config import dependency_injector as gdi from d2o.config import dependency_injector as gdi
HARMONIC_SPACES = [RGSpace((8,), harmonic=True), HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
RGSpace((7,), harmonic=True,zerocenter=True), RGSpace((7,), harmonic=True,zerocenter=True),
RGSpace((8,), harmonic=True,zerocenter=True), RGSpace((8,), harmonic=True,zerocenter=True),
RGSpace((7,8), harmonic=True), RGSpace((7,8), harmonic=True),
RGSpace((7,8), harmonic=True, zerocenter=True), RGSpace((7,8), harmonic=True, zerocenter=True),
RGSpace((6,6), harmonic=True, zerocenter=True), RGSpace((6,6), harmonic=True, zerocenter=True),
RGSpace((7,5), harmonic=True, zerocenter=True), RGSpace((7,5), harmonic=True, zerocenter=True),
RGSpace((5,5), harmonic=True), RGSpace((5,5), harmonic=True),
RGSpace((4,5,7), harmonic=True), RGSpace((4,5,7), harmonic=True),
RGSpace((4,5,7), harmonic=True, zerocenter=True), RGSpace((4,5,7), harmonic=True, zerocenter=True),
LMSpace(6), LMSpace(6),
LMSpace(9)] LMSpace(9)]
#Try all sensible kinds of combinations of spaces, distributuion strategy and #Try all sensible kinds of combinations of spaces, distributuion strategy and
#binning parameters #binning parameters
_maybe_fftw = ["fftw"] if ('pyfftw' in gdi) else [] _maybe_fftw = ["fftw"] if ('fftw_mpi' in gdi) else []
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [None], [None, 3,4], [True, False]) CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [None], [None, 3,4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [[0.,1.3]],[None],[False]) CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [[0.,1.3]],[None],[False])
...@@ -138,13 +138,13 @@ class PowerSpaceConsistencyCheck(unittest.TestCase): ...@@ -138,13 +138,13 @@ class PowerSpaceConsistencyCheck(unittest.TestCase):
binbounds=binbounds) binbounds=binbounds)
assert_equal(p.pindex.flatten()[p.pundex],np.arange(p.dim), assert_equal(p.pindex.flatten()[p.pundex],np.arange(p.dim),
err_msg='pundex is not right-inverse of pindex!') err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS) @expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner, distribution_strategy, def test_rhopindexConsistency(self, harmonic_partner, distribution_strategy,
binbounds, nbin,logarithmic): binbounds, nbin,logarithmic):
assert_equal(p.pindex.flatten().bincount(), p.rho, assert_equal(p.pindex.flatten().bincount(), p.rho,
err_msg='rho is not equal to pindex degeneracy') err_msg='rho is not equal to pindex degeneracy')
class PowerSpaceFunctionalityTest(unittest.TestCase): class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS) @expand(CONSISTENCY_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy, def test_constructor(self, harmonic_partner, distribution_strategy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment