There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit 3aad03b5 authored by Peter Bell's avatar Peter Bell

Basic thread pool implementation

Still not fork safe
parent a9a6e6ab
......@@ -62,9 +62,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <array>
#include <mutex>
#endif
#ifdef POCKETFFT_OPENMP
#include <omp.h>
#endif
#include <mutex>
#include <condition_variable>
#include <thread>
#include <queue>
#include <atomic>
#include <functional>
#if defined(__GNUC__)
......@@ -557,22 +561,167 @@ struct util // hack to avoid duplicate symbols
if (axis>=shape.size()) throw invalid_argument("bad axis number");
}
#ifdef POCKETFFT_OPENMP
static size_t nthreads() { return size_t(omp_get_num_threads()); }
static size_t thread_num() { return size_t(omp_get_thread_num()); }
static size_t thread_count (size_t nthreads, const shape_t &shape,
size_t axis)
{
if (nthreads==1) return 1;
if (prod(shape) < 20*shape[axis]) return 1;
return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads;
return (nthreads==0) ? thread::hardware_concurrency() : nthreads;
}
};
namespace threading {
thread_local size_t thread_id = 0;
thread_local size_t num_threads = 1;
class latch
{
atomic<size_t> num_left_;
mutex mut_;
condition_variable completed_;
using lock_t = unique_lock<mutex>;
public:
latch(size_t n): num_left_(n) {}
void count_down()
{
{
lock_t lock(mut_);
if (--num_left_)
return;
}
completed_.notify_all();
}
void wait()
{
lock_t lock(mut_);
completed_.wait(lock, [this]{ return is_ready(); });
}
bool is_ready() { return num_left_ == 0; }
};
template <typename T> class concurrent_queue
{
queue<T> q_;
mutex mut_;
condition_variable item_added_;
bool shutdown_;
using lock_t = unique_lock<mutex>;
public:
concurrent_queue(): shutdown_(false) {}
void push(T val)
{
{
lock_t lock(mut_);
if (shutdown_)
throw runtime_error("Item added to queue after shutdown");
q_.push(move(val));
}
item_added_.notify_one();
}
bool pop(T & val)
{
lock_t lock(mut_);
item_added_.wait(lock, [this] { return (!q_.empty() || shutdown_); });
if (q_.empty())
return false; // We are shutting down
val = std::move(q_.front());
q_.pop();
return true;
}
void shutdown()
{
{
lock_t lock(mut_);
shutdown_ = true;
}
item_added_.notify_all();
}
#else
static constexpr size_t nthreads() { return 1; }
static constexpr size_t thread_num() { return 0; }
#endif
};
class thread_pool
{
concurrent_queue<function<void()>> work_queue_;
vector<thread> threads_;
void worker_main()
{
function<void()> work;
while (work_queue_.pop(work))
work();
}
public:
explicit thread_pool(size_t nthreads):
threads_(nthreads)
{
for (size_t i=0; i<nthreads; ++i)
{
try { threads_[i] = thread([this]{ worker_main(); }); }
catch (...)
{
work_queue_.shutdown();
for (size_t j = 0; j < i; ++j)
threads_[j].join();
throw;
}
}
}
thread_pool(): thread_pool(thread::hardware_concurrency()) {}
~thread_pool()
{
work_queue_.shutdown();
for (size_t i=0; i<threads_.size(); ++i)
threads_[i].join();
}
void submit(function<void()> work)
{
work_queue_.push(move(work));
}
};
thread_pool & get_pool()
{
static thread_pool pool;
return pool;
}
/** Map a function f over nthreads */
template <typename Func>
void thread_map(size_t nthreads, Func f)
{
auto & pool = get_pool();
if (nthreads == 0)
nthreads = thread::hardware_concurrency();
if (nthreads == 1)
{ f(); return; }
latch counter(nthreads);
for (size_t i=0; i<nthreads; ++i)
{
pool.submit(
[&f, &counter, i, nthreads] {
thread_id = i;
num_threads = nthreads;
f();
counter.count_down();
});
}
counter.wait();
}
}
//
// complex FFTPACK transforms
//
......@@ -2661,10 +2810,10 @@ template<size_t N> class multi_iter
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
idim(idim_), rem(iarr.size()/iarr.shape(idim))
{
auto nshares = util::nthreads();
auto nshares = threading::num_threads;
if (nshares==1) return;
if (nshares==0) throw runtime_error("can't run with zero threads");
auto myshare = util::thread_num();
auto myshare = threading::thread_id;
if (myshare>=nshares) throw runtime_error("impossible share requested");
size_t nbase = rem/nshares;
size_t additional = rem%nshares;
......@@ -2838,12 +2987,6 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
return arr<char>(tmpsize*elemsize);
}
#ifdef POCKETFFT_OPENMP
#define POCKETFFT_NTHREADS nthreads
#else
#define POCKETFFT_NTHREADS
#endif
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<cmplx<T>> &src, cmplx<vtype_t<T>> *POCKETFFT_RESTRICT dst)
{
......@@ -2902,7 +3045,7 @@ template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec>
POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
const shape_t &axes, T0 fct, size_t POCKETFFT_NTHREADS, const Exec & exec,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true)
{
shared_ptr<Tplan> plan;
......@@ -2914,10 +3057,8 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
if ((!plan) || (len!=plan->length()))
plan = get_plan<Tplan>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
threading::thread_map(util::thread_count(nthreads, in.shape(), axes[iax]),
[&] {
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
const auto &tin(iax==0? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
......@@ -2937,7 +3078,7 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
exec(it, tin, out, buf, *plan, fct);
}
} // end of parallel region
}); // end of parallel region
fct = T0(1); // factor has been applied, use 1 for remaining axes
}
}
......@@ -3017,15 +3158,13 @@ struct ExecDcst
template<typename T> POCKETFFT_NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
size_t nthreads)
{
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#endif
{
threading::thread_map(util::thread_count(nthreads, in.shape(), axis),
[&] {
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
#ifndef POCKETFFT_NO_VECTORS
......@@ -3069,19 +3208,17 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
if (i<len)
out[it.oofs(ii)].Set(tdata[i]);
}
} // end of parallel region
}); // end of parallel region
}
template<typename T> POCKETFFT_NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
size_t nthreads)
{
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#endif
{
threading::thread_map(util::thread_count(nthreads, in.shape(), axis),
[&] {
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
#ifndef POCKETFFT_NO_VECTORS
......@@ -3129,7 +3266,7 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
plan->exec(tdata, fct, false);
copy_output(it, tdata, out);
}
} // end of parallel region
}); // end of parallel region
}
struct ExecR2R
......@@ -3152,8 +3289,6 @@ struct ExecR2R
}
};
#undef POCKETFFT_NTHREADS
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward,
const complex<T> *data_in, complex<T> *data_out, T fct,
......
......@@ -23,8 +23,8 @@ if sys.platform == 'darwin':
vars['LDSHARED'] = vars['LDSHARED'].replace('-bundle', '')
python_module_link_args += ['-bundle']
else:
extra_compile_args += ['-DPOCKETFFT_OPENMP', '-fopenmp', '-Wfatal-errors', '-Wfloat-conversion', '-Wsign-conversion', '-Wconversion' ,'-W', '-Wall', '-Wstrict-aliasing=2', '-Wwrite-strings', '-Wredundant-decls', '-Woverloaded-virtual', '-Wcast-qual', '-Wcast-align', '-Wpointer-arith']
python_module_link_args += ['-march=native', '-Wl,-rpath,$ORIGIN', '-fopenmp']
extra_compile_args += ['-Wfatal-errors', '-Wfloat-conversion', '-Wsign-conversion', '-Wconversion' ,'-W', '-Wall', '-Wstrict-aliasing=2', '-Wwrite-strings', '-Wredundant-decls', '-Woverloaded-virtual', '-Wcast-qual', '-Wcast-align', '-Wpointer-arith']
python_module_link_args += ['-march=native', '-Wl,-rpath,$ORIGIN']
# if you don't want debugging info, add "-s" to python_module_link_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment