Commit 88d6e60d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow pyfftw even if MPI is not present

parent 45c8584b
Pipeline #13113 passed with stage
in 5 minutes and 17 seconds
......@@ -32,6 +32,7 @@ dependency_injector = keepers.DependencyInjector(
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw','pyfftw_scalar'))
# Initialize the variables
variable_fft_module = keepers.Variable(
......
......@@ -26,6 +26,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable
pyfftw = gdi.get('pyfftw')
pyfftw_scalar = gdi.get('pyfftw_scalar')
class Transform(Loggable, object):
......@@ -572,6 +573,10 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Enable caching for pyfftw_scalar.interfaces
if 'pyfftw_scalar' in gdi:
pyfftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
......@@ -625,10 +630,16 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
if 'pyfftw_scalar' in gdi:
if self.codomain.harmonic:
result_val = pyfftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes)
else:
result_val = pyfftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment