Commit b0d67512 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

renaming, WIP

parent 88d6e60d
Pipeline #13115 passed with stage
in 5 minutes and 12 seconds
......@@ -31,15 +31,15 @@ dependency_injector = keepers.DependencyInjector(
'pyHealpix',
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw','pyfftw_scalar'))
dependency_injector.register(('pyfftw','fftw_mpi'), lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw','fftw_scalar'))
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['fftw', 'numpy'],
lambda z: (('pyfftw' in dependency_injector)
if z == 'fftw' else True))
['mpi', 'scalar'],
lambda z: (('fftw_mpi' in dependency_injector)
if z == 'mpi' else True))
def dtype_validator(dtype):
......
......@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator):
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
......
......@@ -25,8 +25,8 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable
pyfftw = gdi.get('pyfftw')
pyfftw_scalar = gdi.get('pyfftw_scalar')
fftw_mpi = gdi.get('fftw_mpi')
fftw_scalar = gdi.get('fftw_scalar')
class Transform(Loggable, object):
......@@ -201,20 +201,21 @@ class Transform(Loggable, object):
raise NotImplementedError
class FFTW(Transform):
class MPIFFT(Transform):
"""
The pyfftw pendant of a fft object.
The MPI-parallel FFTW pendant of a fft object.
"""
def __init__(self, domain, codomain):
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
if 'fftw_mpi' not in gdi:
raise ImportError(
"The MPI FFTW module is needed but not available.")
super(FFTW, self).__init__(domain, codomain)
super(MPIFFT, self).__init__(domain, codomain)
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
# Enable caching
fftw_mpi.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
......@@ -410,7 +411,7 @@ class FFTW(Transform):
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The MPI-parallel FFTW transform function.
Parameters
----------
......@@ -468,8 +469,9 @@ class FFTW(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
if fftw_mpi is None:
raise ImportError(
"The MPI FFTW module is needed but not available.")
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
......@@ -513,9 +515,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context,
**kwargs)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.fftn
else:
self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn
self._fftw_interface = fftw_mpi.interfaces.numpy_fft.ifftn
@property
def fftw_interface(self):
......@@ -532,7 +534,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q,
fftw_context,
**kwargs)
self._plan = pyfftw.create_mpi_plan(
self._plan = fftw_mpi.create_mpi_plan(
input_shape=transform_shape,
input_dtype='complex128',
output_dtype='complex128',
......@@ -546,7 +548,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan
class NUMPYFFT(Transform):
class ScalarFFT(Transform):
"""
The numpy fft pendant of a fft object.
......@@ -554,7 +556,7 @@ class NUMPYFFT(Transform):
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The scalar FFT transform function.
Parameters
----------
......@@ -573,9 +575,9 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Enable caching for pyfftw_scalar.interfaces
if 'pyfftw_scalar' in gdi:
pyfftw_scalar.interfaces.cache.enable()
# Enable caching
if 'fftw_scalar' in gdi:
fftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape
if axes is not None and \
......@@ -630,11 +632,11 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if 'pyfftw_scalar' in gdi:
if 'fftw_scalar' in gdi:
if self.codomain.harmonic:
result_val = pyfftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes)
result_val = fftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes)
else:
result_val = pyfftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes)
result_val = fftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes)
else:
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
......
......@@ -18,7 +18,7 @@
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, NUMPYFFT
from rg_transforms import MPIFFT, ScalarFFT
from nifty import RGSpace, nifty_configuration
......@@ -30,18 +30,18 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
if nifty_configuration['fft_module'] == 'mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'scalar':
self._transform = ScalarFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif module == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
if module == 'mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'scalar':
self._transform = ScalarFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported FFT module:' + module)
......
......@@ -62,11 +62,11 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw"], [10, 11], [False, True], [False, True],
@expand(product(["scalar","mpi"], [10, 11], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "mpi" and "fftw_mpi" not in di:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
......@@ -78,12 +78,12 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"], [10, 11], [9, 12], [False, True],
@expand(product(["scalar", "mpi"], [10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "mpi" and "fftw_mpi" not in di:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......
......@@ -32,22 +32,22 @@ from itertools import product, chain
from d2o.config import dependency_injector as gdi
HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
RGSpace((7,), harmonic=True,zerocenter=True),
RGSpace((8,), harmonic=True,zerocenter=True),
RGSpace((7,8), harmonic=True),
RGSpace((7,), harmonic=True,zerocenter=True),
RGSpace((8,), harmonic=True,zerocenter=True),
RGSpace((7,8), harmonic=True),
RGSpace((7,8), harmonic=True, zerocenter=True),
RGSpace((6,6), harmonic=True, zerocenter=True),
RGSpace((7,5), harmonic=True, zerocenter=True),
RGSpace((5,5), harmonic=True),
RGSpace((5,5), harmonic=True),
RGSpace((4,5,7), harmonic=True),
RGSpace((4,5,7), harmonic=True, zerocenter=True),
LMSpace(6),
LMSpace(9)]
#Try all sensible kinds of combinations of spaces, distributuion strategy and
#Try all sensible kinds of combinations of spaces, distributuion strategy and
#binning parameters
_maybe_fftw = ["fftw"] if ('pyfftw' in gdi) else []
_maybe_fftw = ["fftw"] if ('fftw_mpi' in gdi) else []
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [None], [None, 3,4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [[0.,1.3]],[None],[False])
......@@ -138,13 +138,13 @@ class PowerSpaceConsistencyCheck(unittest.TestCase):
binbounds=binbounds)
assert_equal(p.pindex.flatten()[p.pundex],np.arange(p.dim),
err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner, distribution_strategy,
binbounds, nbin,logarithmic):
assert_equal(p.pindex.flatten().bincount(), p.rho,
err_msg='rho is not equal to pindex degeneracy')
class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment