Planned maintenance on Wednesday, 2021-01-20, 17:00-18:00. Expect some interruptions during that time.

Commit 736e7b2d authored by Martin Reinecke

more OpenMP work

parent 2970b68c
...@@ -463,15 +463,18 @@ struct util // hack to avoid duplicate symbols ...@@ -463,15 +463,18 @@ struct util // hack to avoid duplicate symbols
} }
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
static int nthreads() { return omp_get_num_threads(); } static size_t nthreads() { return omp_get_num_threads(); }
static int thread_num() { return omp_get_thread_num(); } static size_t thread_num() { return omp_get_thread_num(); }
static bool run_parallel (const shape_t &shape, size_t axis) static size_t thread_count (size_t nthreads, const shape_t &shape,
{ return prod(shape)/shape[axis] > 20; } // FIXME, needs improvement size_t axis)
{
if (nthreads==1) return 1;
if (prod(shape)/shape[axis] < 20) return 1;
return (nthreads==0) ? omp_get_max_threads() : nthreads;
}
#else #else
static int nthreads() { return 1; } static size_t nthreads() { return 1; }
static int thread_num() { return 0; } static size_t thread_num() { return 0; }
static bool run_parallel (const shape_t &, size_t)
{ return false; }
#endif #endif
}; };
...@@ -2107,14 +2110,15 @@ template<size_t N, typename Ti, typename To> class multi_iter ...@@ -2107,14 +2110,15 @@ template<size_t N, typename Ti, typename To> class multi_iter
} }
public: public:
multi_iter(const ndarr<Ti> &iarr_, ndarr<To> &oarr_, size_t idim_, multi_iter(const ndarr<Ti> &iarr_, ndarr<To> &oarr_, size_t idim_)
size_t nshares=1, size_t myshare=0)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
idim(idim_), rem(iarr.size()/iarr.shape(idim)) idim(idim_), rem(iarr.size()/iarr.shape(idim))
{ {
auto nshares = util::nthreads();
if (nshares==1) return; if (nshares==1) return;
if (nshares==0) throw runtime_error("can't run with zero threads"); if (nshares==0) throw runtime_error("can't run with zero threads");
auto myshare = util::thread_num();
if (myshare>=nshares) throw runtime_error("impossible share requested"); if (myshare>=nshares) throw runtime_error("impossible share requested");
size_t nbase = rem/nshares; size_t nbase = rem/nshares;
size_t additional = rem%nshares; size_t additional = rem%nshares;
...@@ -2218,12 +2222,11 @@ template<typename T> NOINLINE void general_c( ...@@ -2218,12 +2222,11 @@ template<typename T> NOINLINE void general_c(
plan.reset(new pocketfft_c<T>(len)); plan.reset(new pocketfft_c<T>(len));
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axes[iax])) num_threads(nthreads) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif #endif
{ {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(cmplx<T>)); auto storage = alloc_tmp<T>(in.shape(), len, sizeof(cmplx<T>));
multi_iter<vlen, cmplx<T>, cmplx<T>> it(iax==0? in : out, out, axes[iax], multi_iter<vlen, cmplx<T>, cmplx<T>> it(iax==0? in : out, out, axes[iax]);
util::nthreads(), util::thread_num());
#if defined(HAVE_VECSUPPORT) #if defined(HAVE_VECSUPPORT)
if (vlen>1) if (vlen>1)
while (it.remaining()>=vlen) while (it.remaining()>=vlen)
...@@ -2282,12 +2285,11 @@ template<typename T> NOINLINE void general_hartley( ...@@ -2282,12 +2285,11 @@ template<typename T> NOINLINE void general_hartley(
plan.reset(new pocketfft_r<T>(len)); plan.reset(new pocketfft_r<T>(len));
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axes[iax])) num_threads(nthreads) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif #endif
{ {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen, T, T> it(iax==0 ? in : out, out, axes[iax], multi_iter<vlen, T, T> it(iax==0 ? in : out, out, axes[iax]);
util::nthreads(), util::thread_num());
#if defined(HAVE_VECSUPPORT) #if defined(HAVE_VECSUPPORT)
if (vlen>1) if (vlen>1)
while (it.remaining()>=vlen) while (it.remaining()>=vlen)
...@@ -2341,12 +2343,11 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2341,12 +2343,11 @@ template<typename T> NOINLINE void general_r2c(
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
size_t len=in.shape(axis); size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#endif #endif
{ {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen, T, cmplx<T>> it(in, out, axis, multi_iter<vlen, T, cmplx<T>> it(in, out, axis);
util::nthreads(), util::thread_num());
#if defined(HAVE_VECSUPPORT) #if defined(HAVE_VECSUPPORT)
if (vlen>1) if (vlen>1)
while (it.remaining()>=vlen) while (it.remaining()>=vlen)
...@@ -2393,12 +2394,11 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2393,12 +2394,11 @@ template<typename T> NOINLINE void general_c2r(
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
size_t len=out.shape(axis); size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#endif #endif
{ {
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
multi_iter<vlen, cmplx<T>, T> it(in, out, axis, multi_iter<vlen, cmplx<T>, T> it(in, out, axis);
util::nthreads(), util::thread_num());
#if defined(HAVE_VECSUPPORT) #if defined(HAVE_VECSUPPORT)
if (vlen>1) if (vlen>1)
while (it.remaining()>=vlen) while (it.remaining()>=vlen)
...@@ -2450,12 +2450,11 @@ template<typename T> NOINLINE void general_r( ...@@ -2450,12 +2450,11 @@ template<typename T> NOINLINE void general_r(
size_t len=in.shape(axis); size_t len=in.shape(axis);
pocketfft_r<T> plan(len); pocketfft_r<T> plan(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#endif #endif
{ {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen, T, T> it(in, out, axis, multi_iter<vlen, T, T> it(in, out, axis);
util::nthreads(), util::thread_num());
#if defined(HAVE_VECSUPPORT) #if defined(HAVE_VECSUPPORT)
if (vlen>1) if (vlen>1)
while (it.remaining()>=vlen) while (it.remaining()>=vlen)
......
...@@ -93,10 +93,12 @@ py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace, ...@@ -93,10 +93,12 @@ py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
inplace, fwd, nthreads)) inplace, fwd, nthreads))
} }
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace, size_t nthreads) py::array fftn(const py::array &a, py::object axes, double fct, bool inplace,
size_t nthreads)
{ return xfftn(a, axes, fct, inplace, true, nthreads); } { return xfftn(a, axes, fct, inplace, true, nthreads); }
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace, size_t nthreads) py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace,
size_t nthreads)
{ return xfftn(a, axes, fct, inplace, false, nthreads); } { return xfftn(a, axes, fct, inplace, false, nthreads); }
template<typename T> py::array rfftn_internal(const py::array &in, template<typename T> py::array rfftn_internal(const py::array &in,
...@@ -112,7 +114,8 @@ template<typename T> py::array rfftn_internal(const py::array &in, ...@@ -112,7 +114,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
return res; return res;
} }
py::array rfftn(const py::array &in, py::object axes_, double fct, size_t nthreads) py::array rfftn(const py::array &in, py::object axes_, double fct,
size_t nthreads)
{ {
DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct, nthreads)) DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct, nthreads))
} }
...@@ -128,15 +131,18 @@ template<typename T> py::array xrfft_scipy(const py::array &in, ...@@ -128,15 +131,18 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
return res; return res;
} }
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace, size_t nthreads) py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace,
size_t nthreads)
{ {
DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true, nthreads)) DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true,
nthreads))
} }
py::array irfft_scipy(const py::array &in, size_t axis, double fct, py::array irfft_scipy(const py::array &in, size_t axis, double fct,
bool inplace, size_t nthreads) bool inplace, size_t nthreads)
{ {
DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false, nthreads)) DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false,
nthreads))
} }
template<typename T> py::array irfftn_internal(const py::array &in, template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct, size_t nthreads) py::object axes_, size_t lastsize, T fct, size_t nthreads)
...@@ -158,7 +164,8 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -158,7 +164,8 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct, size_t nthreads) double fct, size_t nthreads)
{ {
DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct, nthreads)) DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct,
nthreads))
} }
template<typename T> py::array hartley_internal(const py::array &in, template<typename T> py::array hartley_internal(const py::array &in,
...@@ -175,7 +182,8 @@ template<typename T> py::array hartley_internal(const py::array &in, ...@@ -175,7 +182,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
py::array hartley(const py::array &in, py::object axes_, double fct, py::array hartley(const py::array &in, py::object axes_, double fct,
bool inplace, size_t nthreads) bool inplace, size_t nthreads)
{ {
DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace, nthreads)) DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace,
nthreads))
} }
template<typename T>py::array complex2hartley(const py::array &in, template<typename T>py::array complex2hartley(const py::array &in,
...@@ -229,9 +237,12 @@ py::array mycomplex2hartley(const py::array &in, ...@@ -229,9 +237,12 @@ py::array mycomplex2hartley(const py::array &in,
py::array hartley2(const py::array &in, py::object axes_, double fct, py::array hartley2(const py::array &in, py::object axes_, double fct,
bool inplace, size_t nthreads) bool inplace, size_t nthreads)
{ return mycomplex2hartley(in, rfftn(in, axes_, fct, nthreads), axes_, inplace); } {
return mycomplex2hartley(in, rfftn(in, axes_, fct, nthreads), axes_,
inplace);
}
const char *pypocketfft_DS = R"DELIM(Fast Fourier and Hartley transforms. const char *pypocketfft_DS = R"""(Fast Fourier and Hartley transforms.
This module supports This module supports
- single and double precision - single and double precision
...@@ -241,9 +252,9 @@ This module supports ...@@ -241,9 +252,9 @@ This module supports
For two- and higher-dimensional transforms the code will use SSE2 and AVX For two- and higher-dimensional transforms the code will use SSE2 and AVX
vector instructions for faster execution if these are supported by the CPU and vector instructions for faster execution if these are supported by the CPU and
were enabled during compilation. were enabled during compilation.
)DELIM"; )""";
const char *fftn_DS = R"DELIM( const char *fftn_DS = R"""(
Performs a forward complex FFT. Performs a forward complex FFT.
Parameters Parameters
...@@ -263,9 +274,9 @@ Returns ...@@ -263,9 +274,9 @@ Returns
------- -------
np.ndarray (same shape and data type as a) np.ndarray (same shape and data type as a)
The transformed data. The transformed data.
)DELIM"; )""";
const char *ifftn_DS = R"DELIM(Performs a backward complex FFT. const char *ifftn_DS = R"""(Performs a backward complex FFT.
Parameters Parameters
---------- ----------
...@@ -284,9 +295,9 @@ Returns ...@@ -284,9 +295,9 @@ Returns
------- -------
np.ndarray (same shape and data type as a) np.ndarray (same shape and data type as a)
The transformed data The transformed data
)DELIM"; )""";
const char *rfftn_DS = R"DELIM(Performs a forward real-valued FFT. const char *rfftn_DS = R"""(Performs a forward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -304,9 +315,9 @@ np.ndarray (np.complex64 or np.complex128) ...@@ -304,9 +315,9 @@ np.ndarray (np.complex64 or np.complex128)
The transformed data. The shape is identical to that of the input array, The transformed data. The shape is identical to that of the input array,
except for the axis that was transformed last. If the length of that axis except for the axis that was transformed last. If the length of that axis
was n on input, it is n//2+1 on output. was n on input, it is n//2+1 on output.
)DELIM"; )""";
const char *rfft_scipy_DS = R"DELIM(Performs a forward real-valued FFT. const char *rfft_scipy_DS = R"""(Performs a forward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -326,9 +337,9 @@ np.ndarray (np.float32 or np.float64) ...@@ -326,9 +337,9 @@ np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array. The transformed data. The shape is identical to that of the input array.
Along the transformed axis, values are arranged in Along the transformed axis, values are arranged in
FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`. FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`.
)DELIM"; )""";
const char *irfftn_DS = R"DELIM(Performs a backward real-valued FFT. const char *irfftn_DS = R"""(Performs a backward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -348,9 +359,9 @@ np.ndarray (np.float32 or np.float64) ...@@ -348,9 +359,9 @@ np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array, The transformed data. The shape is identical to that of the input array,
except for the axis that was transformed last, which has now `lastsize` except for the axis that was transformed last, which has now `lastsize`
entries. entries.
)DELIM"; )""";
const char *irfft_scipy_DS = R"DELIM(Performs a backward real-valued FFT. const char *irfft_scipy_DS = R"""(Performs a backward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -369,9 +380,9 @@ Returns ...@@ -369,9 +380,9 @@ Returns
------- -------
np.ndarray (np.float32 or np.float64) np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array. The transformed data. The shape is identical to that of the input array.
)DELIM"; )""";
const char *hartley_DS = R"DELIM(Performs a Hartley transform. const char *hartley_DS = R"""(Performs a Hartley transform.
For every requested axis, a 1D forward Fourier transform is carried out, For every requested axis, a 1D forward Fourier transform is carried out,
and the sum of real and imaginary parts of the result is stored in the output and the sum of real and imaginary parts of the result is stored in the output
array. array.
...@@ -393,7 +404,7 @@ Returns ...@@ -393,7 +404,7 @@ Returns
------- -------
np.ndarray (same shape and data type as a) np.ndarray (same shape and data type as a)
The transformed data The transformed data
)DELIM"; )""";
} // unnamed namespace } // unnamed namespace
...@@ -406,7 +417,8 @@ PYBIND11_MODULE(pypocketfft, m) ...@@ -406,7 +417,8 @@ PYBIND11_MODULE(pypocketfft, m)
"inplace"_a=false, "nthreads"_a=1); "inplace"_a=false, "nthreads"_a=1);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1., m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"inplace"_a=false, "nthreads"_a=1); "inplace"_a=false, "nthreads"_a=1);
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1., "nthreads"_a=1); m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"nthreads"_a=1);
m.def("rfft_scipy",&rfft_scipy, rfft_scipy_DS, "a"_a, "axis"_a, "fct"_a=1., m.def("rfft_scipy",&rfft_scipy, rfft_scipy_DS, "a"_a, "axis"_a, "fct"_a=1.,
"inplace"_a=false, "nthreads"_a=1); "inplace"_a=false, "nthreads"_a=1);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0, m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment