// Patch summary (descriptive preamble placed before the first `diff --git` header).
//
// Purpose: let the pocketfft worker thread pool be shut down and rebuilt, and
// use that to tear the pool down before fork() and recreate it afterwards.
//
// pocketfft_hdronly.h
//   * Includes <pthread.h> only when POCKETFFT_PTHREADS is defined.
//   * concurrent_queue: the member declarations are removed and re-added with
//     the same tokens (appears to be a whitespace-only change); a new
//     restart() sets shutdown_ back to false.
//   * thread_pool: thread creation moves out of the constructor into
//     create_threads(), which spawns threads_.size() workers. If spawning a
//     worker throws, it now calls shutdown() and rethrows (the removed code
//     shut the queue down and joined threads [0, i) inline).
//   * thread_pool::shutdown() shuts the work queue down and joins every thread
//     that is joinable; the destructor now just calls shutdown().
//   * thread_pool::restart() calls work_queue_.restart() and then
//     create_threads().
//   * get_pool(): under POCKETFFT_PTHREADS, a static once_flag/call_once
//     registers pthread_atfork handlers: prepare -> get_pool().shutdown(),
//     parent and child -> get_pool().restart().
//   * thread_map(): the get_pool() call moves below the nthreads == 1 early
//     return, so that path no longer calls get_pool().
//
// setup.py
//   * Adds define_macros = [('POCKETFFT_PTHREADS', None)] and passes it to the
//     extension via define_macros=define_macros.
//
// NOTE(review): concurrent_queue::restart() assigns shutdown_ without taking
//   mut_. In this patch it is only called from thread_pool::restart(), and the
//   atfork handlers call shutdown() (which joins the workers) first -- confirm
//   there is no other caller.
// NOTE(review): the return value of pthread_atfork() is discarded.
// NOTE(review): the visible setup.py hunk defines POCKETFFT_PTHREADS
//   unconditionally; presumably <pthread.h> is missing on some toolchains
//   (e.g. MSVC) -- confirm whether a platform branch outside this hunk
//   overrides define_macros.
diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 2340ac29e2043848d186bada50475916d0e1fd39..dec087d102bbe320dcaedd115490bbfcf2f94809 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -70,6 +70,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include <atomic> #include <functional> +#ifdef POCKETFFT_PTHREADS +# include <pthread.h> +#endif + #if defined(__GNUC__) #define POCKETFFT_NOINLINE __attribute__((noinline)) @@ -604,11 +608,11 @@ class latch template <typename T> class concurrent_queue { - queue<T> q_; - mutex mut_; - condition_variable item_added_; - bool shutdown_; - using lock_t = unique_lock<mutex>; + queue<T> q_; + mutex mut_; + condition_variable item_added_; + bool shutdown_; + using lock_t = unique_lock<mutex>; public: concurrent_queue(): shutdown_(false) {} @@ -644,6 +648,8 @@ template <typename T> class concurrent_queue } item_added_.notify_all(); } + + void restart() { shutdown_ = false; } }; class thread_pool @@ -658,41 +664,64 @@ class thread_pool work(); } - public: - explicit thread_pool(size_t nthreads): - threads_(nthreads) + void create_threads() { + size_t nthreads = threads_.size(); for (size_t i=0; i<nthreads; ++i) { try { threads_[i] = thread([this]{ worker_main(); }); } catch (...) 
{ - work_queue_.shutdown(); - for (size_t j = 0; j < i; ++j) - threads_[j].join(); + shutdown(); throw; } } } + public: + explicit thread_pool(size_t nthreads): + threads_(nthreads) + { create_threads(); } + thread_pool(): thread_pool(thread::hardware_concurrency()) {} - ~thread_pool() + ~thread_pool() { shutdown(); } + + void submit(function<void()> work) + { + work_queue_.push(move(work)); + } + + void shutdown() { work_queue_.shutdown(); - for (size_t i=0; i<threads_.size(); ++i) - threads_[i].join(); + for (auto &thread : threads_) + if (thread.joinable()) + thread.join(); } - void submit(function<void()> work) + void restart() { - work_queue_.push(move(work)); + work_queue_.restart(); + create_threads(); } }; thread_pool & get_pool() { static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static once_flag f; + call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + return pool; } @@ -700,13 +729,13 @@ thread_pool & get_pool() template <typename Func> void thread_map(size_t nthreads, Func f) { - auto & pool = get_pool(); if (nthreads == 0) nthreads = thread::hardware_concurrency(); if (nthreads == 1) { f(); return; } + auto & pool = get_pool(); latch counter(nthreads); for (size_t i=0; i<nthreads; ++i) { diff --git a/setup.py b/setup.py index 85867cde80c8355bfaca4deb2ca9ebf1c3f64654..e85518bf04e674c9767dd8bf0eca6a458e2f0cac 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ include_dirs = ['./', _deferred_pybind11_include(True), _deferred_pybind11_include()] extra_compile_args = ['--std=c++11', '-march=native', '-O3'] python_module_link_args = [] +define_macros = [('POCKETFFT_PTHREADS', None)] if sys.platform == 'darwin': import distutils.sysconfig @@ -35,6 +36,7 @@ def get_extension_modules(): sources=['pypocketfft.cc'], depends=['pocketfft_hdronly.h'], include_dirs=include_dirs, + define_macros=define_macros, 
extra_compile_args=extra_compile_args, extra_link_args=python_module_link_args)]