Commit 0b65b8fb authored by Peter Bell's avatar Peter Bell

Handle fork correctly in thread pool

parent 3aad03b5
......@@ -70,6 +70,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <atomic>
#include <functional>
#ifdef POCKETFFT_PTHREADS
# include <pthread.h>
#endif
#if defined(__GNUC__)
#define POCKETFFT_NOINLINE __attribute__((noinline))
......@@ -604,11 +608,11 @@ class latch
template <typename T> class concurrent_queue
{
queue<T> q_;
mutex mut_;
condition_variable item_added_;
bool shutdown_;
using lock_t = unique_lock<mutex>;
queue<T> q_;
mutex mut_;
condition_variable item_added_;
bool shutdown_;
using lock_t = unique_lock<mutex>;
public:
concurrent_queue(): shutdown_(false) {}
......@@ -644,6 +648,8 @@ template <typename T> class concurrent_queue
}
item_added_.notify_all();
}
void restart() { shutdown_ = false; }
};
class thread_pool
......@@ -658,41 +664,64 @@ class thread_pool
work();
}
public:
explicit thread_pool(size_t nthreads):
threads_(nthreads)
void create_threads()
{
size_t nthreads = threads_.size();
for (size_t i=0; i<nthreads; ++i)
{
try { threads_[i] = thread([this]{ worker_main(); }); }
catch (...)
{
work_queue_.shutdown();
for (size_t j = 0; j < i; ++j)
threads_[j].join();
shutdown();
throw;
}
}
}
public:
explicit thread_pool(size_t nthreads):
threads_(nthreads)
{ create_threads(); }
thread_pool(): thread_pool(thread::hardware_concurrency()) {}
~thread_pool()
~thread_pool() { shutdown(); }
void submit(function<void()> work)
{
work_queue_.push(move(work));
}
void shutdown()
{
work_queue_.shutdown();
for (size_t i=0; i<threads_.size(); ++i)
threads_[i].join();
for (auto &thread : threads_)
if (thread.joinable())
thread.join();
}
void submit(function<void()> work)
void restart()
{
work_queue_.push(move(work));
work_queue_.restart();
create_threads();
}
};
thread_pool & get_pool()
{
static thread_pool pool;
#ifdef POCKETFFT_PTHREADS
static once_flag f;
call_once(f,
[]{
pthread_atfork(
+[]{ get_pool().shutdown(); }, // prepare
+[]{ get_pool().restart(); }, // parent
+[]{ get_pool().restart(); } // child
);
});
#endif
return pool;
}
......@@ -700,13 +729,13 @@ thread_pool & get_pool()
template <typename Func>
void thread_map(size_t nthreads, Func f)
{
auto & pool = get_pool();
if (nthreads == 0)
nthreads = thread::hardware_concurrency();
if (nthreads == 1)
{ f(); return; }
auto & pool = get_pool();
latch counter(nthreads);
for (size_t i=0; i<nthreads; ++i)
{
......
......@@ -15,6 +15,7 @@ include_dirs = ['./', _deferred_pybind11_include(True),
_deferred_pybind11_include()]
extra_compile_args = ['--std=c++11', '-march=native', '-O3']
python_module_link_args = []
define_macros = [('POCKETFFT_PTHREADS', None)]
if sys.platform == 'darwin':
import distutils.sysconfig
......@@ -35,6 +36,7 @@ def get_extension_modules():
sources=['pypocketfft.cc'],
depends=['pocketfft_hdronly.h'],
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=python_module_link_args)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment