Commit 88d6e60d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow pyfftw even if MPI is not present

parent 45c8584b
Pipeline #13113 passed with stage
in 5 minutes and 17 seconds
...@@ -32,6 +32,7 @@ dependency_injector = keepers.DependencyInjector( ...@@ -32,6 +32,7 @@ dependency_injector = keepers.DependencyInjector(
'plotly']) 'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI')) dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
dependency_injector.register(('pyfftw','pyfftw_scalar'))
# Initialize the variables # Initialize the variables
variable_fft_module = keepers.Variable( variable_fft_module = keepers.Variable(
......
...@@ -26,6 +26,7 @@ import nifty.nifty_utilities as utilities ...@@ -26,6 +26,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable from keepers import Loggable
pyfftw = gdi.get('pyfftw') pyfftw = gdi.get('pyfftw')
pyfftw_scalar = gdi.get('pyfftw_scalar')
class Transform(Loggable, object): class Transform(Loggable, object):
...@@ -572,6 +573,10 @@ class NUMPYFFT(Transform): ...@@ -572,6 +573,10 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# Enable caching for pyfftw_scalar.interfaces
if 'pyfftw_scalar' in gdi:
pyfftw_scalar.interfaces.cache.enable()
# Check if the axes provided are valid given the shape # Check if the axes provided are valid given the shape
if axes is not None and \ if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes): not all(axis in range(len(val.shape)) for axis in axes):
...@@ -625,10 +630,16 @@ class NUMPYFFT(Transform): ...@@ -625,10 +630,16 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes) local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation # perform the transformation
if self.codomain.harmonic: if 'pyfftw_scalar' in gdi:
result_val = np.fft.fftn(local_val, axes=axes) if self.codomain.harmonic:
result_val = pyfftw_scalar.interfaces.numpy_fft.fftn(local_val, axes=axes)
else:
result_val = pyfftw_scalar.interfaces.numpy_fft.ifftn(local_val, axes=axes)
else: else:
result_val = np.fft.ifftn(local_val, axes=axes) if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
# Apply domain centering mask # Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter): if reduce(lambda x, y: x + y, self.domain.zerocenter):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment