Commit 039a629a authored by Martin Reinecke's avatar Martin Reinecke

streamline the FFT interface

parent 5ed8f324
Pipeline #29434 passed with stages
in 11 minutes and 52 seconds
......@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom.check_codomain(target)
target.check_codomain(adom)
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
utilities.fft_prep()
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return self._apply_cartesian(x, mode)
def _apply_cartesian(self, x, mode):
from pyfftw.interfaces.numpy_fft import fftn
axes = x.domain.axes[self._space]
tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val)
......@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fftn(ldat2, axes=(1,))
ldat2 = utilities.my_fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
......
......@@ -23,7 +23,8 @@ import abc
from future.utils import with_metaclass
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "hartley", "my_fftn_r2c"]
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn"]
def get_slice_list(shape, axes):
......@@ -167,6 +168,17 @@ def nthreads():
return nthreads._val
nthreads._val = None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE')
def fft_prep():
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
......@@ -177,7 +189,7 @@ def hartley(a, axes=None):
raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes, threads=nthreads())
tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
def _fill_array(tmp, res, axes):
if axes is None:
......@@ -219,7 +231,7 @@ def my_fftn_r2c(a, axes=None):
raise TypeError("Transform requires real-valued input arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes, threads=nthreads())
tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
def _fill_complex_array(tmp, res, axes):
if axes is None:
......@@ -250,3 +262,8 @@ def my_fftn_r2c(a, axes=None):
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def my_fftn(a, axes=None):
from pyfftw.interfaces.numpy_fft import fftn
return fftn(a, axes=axes, **_fft_extra_args)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment