Commit a67eda68 authored by Theo Steininger's avatar Theo Steininger

Merge branch 'fftw_for_the_masses' into 'master'

allow pyfftw even if MPI is not present

See merge request !144
parents 05d1d092 bb5cd1e6
Pipeline #13335 passed with stages
in 11 minutes and 7 seconds
......@@ -28,17 +28,25 @@ __all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
('pyfftw', 'fftw'),
'pyHealpix',
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
def _fft_module_checker(z):
if z == 'fftw_mpi':
return hasattr(dependency_injector.get('fftw'), 'FFTW_MPI')
if z == 'fftw':
return ('fftw' in dependency_injector)
if z == 'numpy':
return True
return False
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['fftw', 'numpy'],
lambda z: (('pyfftw' in dependency_injector)
if z == 'fftw' else True))
['fftw_mpi', 'fftw', 'numpy'],
_fft_module_checker)
def dtype_validator(dtype):
......
......@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator):
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
......
......@@ -25,7 +25,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable
pyfftw = gdi.get('pyfftw')
fftw = gdi.get('fftw')
class Transform(Loggable, object):
......@@ -200,20 +200,21 @@ class Transform(Loggable, object):
raise NotImplementedError
class FFTW(Transform):
class MPIFFT(Transform):
"""
The pyfftw pendant of a fft object.
The MPI-parallel FFTW pendant of a fft object.
"""
def __init__(self, domain, codomain):
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
super(FFTW, self).__init__(domain, codomain)
super(MPIFFT, self).__init__(domain, codomain)
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
# Enable caching
fftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
......@@ -409,7 +410,7 @@ class FFTW(Transform):
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The MPI-parallel FFTW transform function.
Parameters
----------
......@@ -467,8 +468,9 @@ class FFTW(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
......@@ -512,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context,
**kwargs)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
self._fftw_interface = fftw.interfaces.numpy_fft.fftn
else:
self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn
self._fftw_interface = fftw.interfaces.numpy_fft.ifftn
@property
def fftw_interface(self):
......@@ -531,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q,
fftw_context,
**kwargs)
self._plan = pyfftw.create_mpi_plan(
self._plan = fftw.create_mpi_plan(
input_shape=transform_shape,
input_dtype='complex128',
output_dtype='complex128',
......@@ -545,15 +547,26 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan
class NUMPYFFT(Transform):
class SerialFFT(Transform):
"""
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain, use_fftw):
super(SerialFFT, self).__init__(domain, codomain)
if use_fftw and (fftw is None):
raise ImportError(
"The serial FFTW module is needed but not available.")
self._use_fftw = use_fftw
# Enable caching
if self._use_fftw:
fftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The scalar FFT transform function.
Parameters
----------
......@@ -572,6 +585,7 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
......@@ -625,10 +639,18 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
if self._use_fftw:
if self.codomain.harmonic:
result_val = fftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
else:
result_val = fftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
......
......@@ -18,7 +18,7 @@
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, NUMPYFFT
from rg_transforms import MPIFFT, SerialFFT
from nifty import RGSpace, nifty_configuration
......@@ -30,20 +30,18 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module'])
module = nifty_configuration['fft_module']
if module == 'fftw_mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'fftw':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=True)
elif module == 'numpy':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=False)
else:
if module == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif module == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported FFT module:' + module)
raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods---
......
......@@ -18,7 +18,7 @@
from parameterized import parameterized
from nifty import RGSpace, LMSpace, HPSpace, GLSpace, PowerSpace
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
def custom_name_func(testcase_func, param_num, param):
......@@ -36,7 +36,7 @@ def expand(*args, **kwargs):
def generate_spaces():
spaces = [RGSpace(4), PowerSpace(RGSpace((4, 4), harmonic=True)),
LMSpace(5), HPSpace(4)]
if 'pyHealpix' in di:
if 'pyHealpix' in gdi:
spaces.append(GLSpace(4))
return spaces
......
from nifty import *
#This tests if it is possible to import all of nifties methods. Experience shows this is not always possible.
pass
\ No newline at end of file
# This tests if it is possible to import all of Nifty's methods.
# Experience shows this is not always possible.
pass
......@@ -20,7 +20,7 @@ import unittest
import numpy as np
from numpy.testing import assert_equal,\
assert_allclose
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
from nifty import Field,\
RGSpace,\
LMSpace,\
......@@ -62,11 +62,15 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw"], [10, 11], [False, True], [False, True],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
......@@ -78,12 +82,16 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"], [10, 11], [9, 12], [False, True],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......@@ -99,7 +107,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([0, 3, 6, 11, 30],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -113,7 +121,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
......@@ -126,7 +134,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -142,7 +150,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......
......@@ -24,7 +24,7 @@ from numpy.testing import assert_, assert_equal, assert_raises,\
assert_almost_equal
from nose.plugins.skip import SkipTest
from nifty import GLSpace
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
from test.common import expand
# [nlat, nlon, expected]
......@@ -99,7 +99,7 @@ class GLSpaceFunctionalityTests(unittest.TestCase):
except ImportError:
raise SkipTest
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
else:
g = GLSpace(2)
......
......@@ -28,30 +28,36 @@ from nifty import PowerSpace, RGSpace, Space, LMSpace
from types import NoneType
from test.common import expand
from itertools import product, chain
#needed to check wether fftw is available
from d2o.config import dependency_injector as gdi
# needed to check wether fftw is available
from nifty import dependency_injector as gdi
from nose.plugins.skip import SkipTest
HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
RGSpace((7,), harmonic=True,zerocenter=True),
RGSpace((8,), harmonic=True,zerocenter=True),
RGSpace((7,8), harmonic=True),
RGSpace((7,8), harmonic=True, zerocenter=True),
RGSpace((6,6), harmonic=True, zerocenter=True),
RGSpace((7,5), harmonic=True, zerocenter=True),
RGSpace((5,5), harmonic=True),
RGSpace((4,5,7), harmonic=True),
RGSpace((4,5,7), harmonic=True, zerocenter=True),
LMSpace(6),
LMSpace(9)]
RGSpace((7,), harmonic=True, zerocenter=True),
RGSpace((8,), harmonic=True, zerocenter=True),
RGSpace((7, 8), harmonic=True),
RGSpace((7, 8), harmonic=True, zerocenter=True),
RGSpace((6, 6), harmonic=True, zerocenter=True),
RGSpace((7, 5), harmonic=True, zerocenter=True),
RGSpace((5, 5), harmonic=True),
RGSpace((4, 5, 7), harmonic=True),
RGSpace((4, 5, 7), harmonic=True, zerocenter=True),
LMSpace(6),
LMSpace(9)]
#Try all sensible kinds of combinations of spaces, distributuion strategy and
#binning parameters
_maybe_fftw = ["fftw"] if ('pyfftw' in gdi) else []
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [None], [None, 3,4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [[0.,1.3]],[None],[False])
CONSISTENCY_CONFIGS = chain(CONSISTENCY_CONFIGS_IMPLICIT, CONSISTENCY_CONFIGS_EXPLICIT)
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES,
["not", "equal", "fftw"],
[None], [None, 3, 4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES,
["not", "equal", "fftw"],
[[0., 1.3]], [None], [False])
CONSISTENCY_CONFIGS = chain(CONSISTENCY_CONFIGS_IMPLICIT,
CONSISTENCY_CONFIGS_EXPLICIT)
# [harmonic_partner, distribution_strategy,
# logarithmic, nbin, binbounds, expected]
......@@ -128,10 +134,29 @@ class PowerSpaceInterfaceTest(unittest.TestCase):
p = PowerSpace(r)
assert_(isinstance(getattr(p, attribute), expected_type))
class PowerSpaceConsistencyCheck(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS)
def test_pipundexInversion(self, harmonic_partner, distribution_strategy,
binbounds, nbin,logarithmic):
binbounds, nbin, logarithmic):
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin,
binbounds=binbounds)
assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim),
err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner,
distribution_strategy, binbounds, nbin,
logarithmic):
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
print (gdi.get('fftw'), "blub \n\n\n")
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin,
......@@ -139,16 +164,17 @@ class PowerSpaceConsistencyCheck(unittest.TestCase):
assert_equal(p.pindex.flatten()[p.pundex],np.arange(p.dim),
err_msg='pundex is not right-inverse of pindex!')
@expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner, distribution_strategy,
binbounds, nbin,logarithmic):
assert_equal(p.pindex.flatten().bincount(), p.rho,
err_msg='rho is not equal to pindex degeneracy')
class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS)
@expand(CONSTRUCTOR_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy,
logarithmic, nbin, binbounds, expected):
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
raise SkipTest
if 'error' in expected:
with assert_raises(expected['error']):
PowerSpace(harmonic_partner=harmonic_partner,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment