Commit 0b65b8fb authored by Peter Bell's avatar Peter Bell
Browse files

Handle fork correctly in thread pool

parent 3aad03b5
...@@ -70,6 +70,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -70,6 +70,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#ifdef POCKETFFT_PTHREADS
# include <pthread.h>
#endif
#if defined(__GNUC__) #if defined(__GNUC__)
#define POCKETFFT_NOINLINE __attribute__((noinline)) #define POCKETFFT_NOINLINE __attribute__((noinline))
...@@ -604,11 +608,11 @@ class latch ...@@ -604,11 +608,11 @@ class latch
template <typename T> class concurrent_queue template <typename T> class concurrent_queue
{ {
queue<T> q_; queue<T> q_;
mutex mut_; mutex mut_;
condition_variable item_added_; condition_variable item_added_;
bool shutdown_; bool shutdown_;
using lock_t = unique_lock<mutex>; using lock_t = unique_lock<mutex>;
public: public:
concurrent_queue(): shutdown_(false) {} concurrent_queue(): shutdown_(false) {}
...@@ -644,6 +648,8 @@ template <typename T> class concurrent_queue ...@@ -644,6 +648,8 @@ template <typename T> class concurrent_queue
} }
item_added_.notify_all(); item_added_.notify_all();
} }
void restart() { shutdown_ = false; }
}; };
class thread_pool class thread_pool
...@@ -658,41 +664,64 @@ class thread_pool ...@@ -658,41 +664,64 @@ class thread_pool
work(); work();
} }
public: void create_threads()
explicit thread_pool(size_t nthreads):
threads_(nthreads)
{ {
size_t nthreads = threads_.size();
for (size_t i=0; i<nthreads; ++i) for (size_t i=0; i<nthreads; ++i)
{ {
try { threads_[i] = thread([this]{ worker_main(); }); } try { threads_[i] = thread([this]{ worker_main(); }); }
catch (...) catch (...)
{ {
work_queue_.shutdown(); shutdown();
for (size_t j = 0; j < i; ++j)
threads_[j].join();
throw; throw;
} }
} }
} }
public:
explicit thread_pool(size_t nthreads):
threads_(nthreads)
{ create_threads(); }
thread_pool(): thread_pool(thread::hardware_concurrency()) {} thread_pool(): thread_pool(thread::hardware_concurrency()) {}
~thread_pool() ~thread_pool() { shutdown(); }
void submit(function<void()> work)
{
work_queue_.push(move(work));
}
void shutdown()
{ {
work_queue_.shutdown(); work_queue_.shutdown();
for (size_t i=0; i<threads_.size(); ++i) for (auto &thread : threads_)
threads_[i].join(); if (thread.joinable())
thread.join();
} }
void submit(function<void()> work) void restart()
{ {
work_queue_.push(move(work)); work_queue_.restart();
create_threads();
} }
}; };
thread_pool & get_pool() thread_pool & get_pool()
{ {
static thread_pool pool; static thread_pool pool;
#ifdef POCKETFFT_PTHREADS
static once_flag f;
call_once(f,
[]{
pthread_atfork(
+[]{ get_pool().shutdown(); }, // prepare
+[]{ get_pool().restart(); }, // parent
+[]{ get_pool().restart(); } // child
);
});
#endif
return pool; return pool;
} }
...@@ -700,13 +729,13 @@ thread_pool & get_pool() ...@@ -700,13 +729,13 @@ thread_pool & get_pool()
template <typename Func> template <typename Func>
void thread_map(size_t nthreads, Func f) void thread_map(size_t nthreads, Func f)
{ {
auto & pool = get_pool();
if (nthreads == 0) if (nthreads == 0)
nthreads = thread::hardware_concurrency(); nthreads = thread::hardware_concurrency();
if (nthreads == 1) if (nthreads == 1)
{ f(); return; } { f(); return; }
auto & pool = get_pool();
latch counter(nthreads); latch counter(nthreads);
for (size_t i=0; i<nthreads; ++i) for (size_t i=0; i<nthreads; ++i)
{ {
......
...@@ -15,6 +15,7 @@ include_dirs = ['./', _deferred_pybind11_include(True), ...@@ -15,6 +15,7 @@ include_dirs = ['./', _deferred_pybind11_include(True),
_deferred_pybind11_include()] _deferred_pybind11_include()]
extra_compile_args = ['--std=c++11', '-march=native', '-O3'] extra_compile_args = ['--std=c++11', '-march=native', '-O3']
python_module_link_args = [] python_module_link_args = []
define_macros = [('POCKETFFT_PTHREADS', None)]
if sys.platform == 'darwin': if sys.platform == 'darwin':
import distutils.sysconfig import distutils.sysconfig
...@@ -35,6 +36,7 @@ def get_extension_modules(): ...@@ -35,6 +36,7 @@ def get_extension_modules():
sources=['pypocketfft.cc'], sources=['pypocketfft.cc'],
depends=['pocketfft_hdronly.h'], depends=['pocketfft_hdronly.h'],
include_dirs=include_dirs, include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=python_module_link_args)] extra_link_args=python_module_link_args)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment