Commit c398a36a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add multithreading for FFTs

parent adc1eca5
Pipeline #27834 failed with stages
in 4 minutes and 34 seconds
......@@ -63,6 +63,7 @@ class FFTOperator(LinearOperator):
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -100,7 +101,8 @@ class FFTOperator(LinearOperator):
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes,
threads=utilities.nthreads())
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
......
......@@ -160,6 +160,15 @@ def NiftyMetaBase():
return with_metaclass(NiftyMeta, type('NewBase', (object,), {}))
def nthreads():
if nthreads._val is None:
import os
nthreads._val = int(os.getenv("OMP_NUM_THREADS", "1"))
print("value=",nthreads._val)
return nthreads._val
nthreads._val=None
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
......@@ -169,7 +178,7 @@ def hartley(a, axes=None):
raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes)
tmp = rfftn(a, axes=axes, threads=nthreads())
def _fill_array(tmp, res, axes):
if axes is None:
......@@ -211,7 +220,7 @@ def my_fftn_r2c(a, axes=None):
raise TypeError("Transform requires real-valued input arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes)
tmp = rfftn(a, axes=axes, threads=nthreads())
def _fill_complex_array(tmp, res, axes):
if axes is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment