Commit 891b5454 authored by Peter Bell

Better parallelisability heuristic

parent 0b65b8fb
...@@ -566,11 +566,16 @@ struct util // hack to avoid duplicate symbols ...@@ -566,11 +566,16 @@ struct util // hack to avoid duplicate symbols
} }
static size_t thread_count (size_t nthreads, const shape_t &shape, static size_t thread_count (size_t nthreads, const shape_t &shape,
size_t axis) size_t axis, size_t vlen)
{ {
if (nthreads==1) return 1; if (nthreads==1) return 1;
if (prod(shape) < 20*shape[axis]) return 1; size_t size = prod(shape);
return (nthreads==0) ? thread::hardware_concurrency() : nthreads; size_t parallel = size / (shape[axis] * vlen);
if (shape[axis] < 1000)
parallel /= 4;
size_t max_threads = nthreads == 0 ?
thread::hardware_concurrency() : nthreads;
return max(size_t(1), min(parallel, max_threads));
} }
}; };
...@@ -3086,7 +3091,7 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out, ...@@ -3086,7 +3091,7 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan = get_plan<Tplan>(len); plan = get_plan<Tplan>(len);
threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax]), threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax], vlen),
[&] { [&] {
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
const auto &tin(iax==0? in : out); const auto &tin(iax==0? in : out);
...@@ -3192,7 +3197,7 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c( ...@@ -3192,7 +3197,7 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis); size_t len=in.shape(axis);
threading::thread_map(util::thread_count(nthreads, in.shape(), axis), threading::thread_map(util::thread_count(nthreads, in.shape(), axis, vlen),
[&] { [&] {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis); multi_iter<vlen> it(in, out, axis);
...@@ -3246,7 +3251,7 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r( ...@@ -3246,7 +3251,7 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis); size_t len=out.shape(axis);
threading::thread_map(util::thread_count(nthreads, in.shape(), axis), threading::thread_map(util::thread_count(nthreads, in.shape(), axis, vlen),
[&] { [&] {
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T)); auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis); multi_iter<vlen> it(in, out, axis);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment