From 1816b95ea0e7ead1f4b186f35dbd413fbb1114b3 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 14 Aug 2019 20:58:54 +0200
Subject: [PATCH] allow controlling the number of threads

---
 nifty_gridder.cc | 64 ++++++++++++++++++++++++++++++++++--------------
 1 file changed, 46 insertions(+), 18 deletions(-)

diff --git a/nifty_gridder.cc b/nifty_gridder.cc
index e499fbd..3e0dfa8 100644
--- a/nifty_gridder.cc
+++ b/nifty_gridder.cc
@@ -24,7 +24,6 @@
 #include <iostream>
 #include <algorithm>
 
-#define POCKETFFT_OPENMP
 #include "pocketfft_hdronly.h"
 
 #ifdef __GNUC__
@@ -63,6 +62,32 @@ template<typename T> inline T fmodulo (T v1, T v2)
   return (tmp==v2) ? T(0) : tmp;
   }
 
+static size_t nthreads = 1;
+
+constexpr auto set_nthreads_DS = R"""(
+Specifies the number of threads to be used by the module
+
+Parameters
+==========
+nthreads: int
+    the number of threads to be used. Must be >=1.
+)""";
+void set_nthreads(size_t nthreads_)
+  {
+  if (nthreads_<1) throw runtime_error("nthreads must be >= 1");
+  nthreads = nthreads_;
+  }
+
+constexpr auto get_nthreads_DS = R"""(
+Returns the number of threads used by the module
+
+Returns
+=======
+int : the number of threads used by the module
+)""";
+size_t get_nthreads()
+  { return nthreads;}
+
 //
 // Utilities for Gauss-Legendre quadrature
 //
@@ -81,7 +106,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w)
   double t0 = 1 - (1-1./n) / (8.*n*n);
   double t1 = 1./(4.*n+2.);
 
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   int i;
 #pragma omp for schedule(dynamic,100)
@@ -213,7 +238,7 @@ template<typename T> pyarr_c<T> complex2hartley
   auto grid2 = res.mutable_data();
   {
   py::gil_scoped_release release;
-#pragma omp parallel for
+#pragma omp parallel for num_threads(nthreads)
   for (size_t u=0; u<nu; ++u)
     {
     size_t xu = (u==0) ? 0 : nu-u;
@@ -241,7 +266,7 @@ template<typename T> pyarr_c<complex<T>> hartley2complex
   auto grid2 = res.mutable_data();
   {
   py::gil_scoped_release release;
-#pragma omp parallel for
+#pragma omp parallel for num_threads(nthreads)
   for (size_t u=0; u<nu; ++u)
     {
     size_t xu = (u==0) ? 0 : nu-u;
@@ -268,8 +293,9 @@ template<typename T> void hartley2_2D(const pyarr_c<T> &in, pyarr_c<T> &out)
   auto ptmp = out.mutable_data();
   {
   py::gil_scoped_release release;
-  pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1), 0);
-#pragma omp parallel for
+  pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1),
+    nthreads);
+#pragma omp parallel for num_threads(nthreads)
   for(size_t i=1; i<(nu+1)/2; ++i)
     for(size_t j=1; j<(nv+1)/2; ++j)
        {
@@ -299,7 +325,7 @@ vector<double> correction_factors (size_t n, size_t nval, size_t w)
   for (auto &v:psi)
     v = exp(beta*(sqrt(1-v*v)-1.));
   vector<double> res(nval);
-#pragma omp parallel for schedule(static)
+#pragma omp parallel for schedule(static) num_threads(nthreads)
   for (size_t k=0; k<nval; ++k)
     {
     double tmp=0;
@@ -397,7 +423,7 @@ template<typename T> class Baselines
       auto vis = res.mutable_data();
       {
       py::gil_scoped_release release;
-#pragma omp parallel for
+#pragma omp parallel for num_threads(nthreads)
       for (size_t i=0; i<nvis; ++i)
         {
         auto t = idx(i);
@@ -439,7 +465,7 @@ template<typename T> class Baselines
       auto ms = res.template mutable_unchecked<2>();
       {
       py::gil_scoped_release release;
-#pragma omp parallel for
+#pragma omp parallel for num_threads(nthreads)
       for (size_t i=0; i<nvis; ++i)
         {
         auto t = idx(i);
@@ -578,7 +604,7 @@ template<typename T> class GridderConfig
       auto ptmp = tmp.mutable_data();
       pocketfft::c2c({nu,nv},{grid.strides(0),grid.strides(1)},
         {tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::BACKWARD,
-        grid.data(), tmp.mutable_data(), T(1), 0);
+        grid.data(), tmp.mutable_data(), T(1), nthreads);
       auto res = makeArray<complex<T>>({nx_dirty, ny_dirty});
       auto pout = res.mutable_data();
       {
@@ -640,7 +666,7 @@ template<typename T> class GridderConfig
           ptmp[nv*i2+j2] = pdirty[ny_dirty*i + j]*cfu[i]*cfv[j];
           }
       pocketfft::c2c({nu,nv}, strides, strides, {0,1}, pocketfft::FORWARD,
-        ptmp, ptmp, T(1), 0);
+        ptmp, ptmp, T(1), nthreads);
       }
       return tmp;
       }
@@ -795,7 +821,7 @@ template<typename T> pyarr_c<complex<T>> vis2grid_c(
   T beta = gconf.Beta();
   size_t w = gconf.W();
 
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, nullptr, grid);
   T emb = exp(-2*beta);
@@ -894,7 +920,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c(
   T beta = gconf.Beta();
   size_t w = gconf.W();
 
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, nullptr, grid);
   T emb = exp(-2*beta);
@@ -954,7 +980,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c_wgt(
   T beta = gconf.Beta();
   size_t w = gconf.W();
 
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, nullptr, grid);
   T emb = exp(-2*beta);
@@ -1011,7 +1037,7 @@ template<typename T> pyarr_c<complex<T>> grid2vis_c(
   size_t w = gconf.W();
 
   // Loop over sampling points
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, grid, nullptr);
   T emb = exp(-2*beta);
@@ -1087,7 +1113,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c(const Baselines<T> &baselines
   size_t w = gconf.W();
 
   // Loop over sampling points
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, grid, nullptr);
   T emb = exp(-2*beta);
@@ -1149,7 +1175,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c_wgt(
   size_t w = gconf.W();
 
   // Loop over sampling points
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, grid, nullptr);
   T emb = exp(-2*beta);
@@ -1210,7 +1236,7 @@ template<typename T> pyarr_c<complex<T>> apply_holo(
   size_t w = gconf.W();
 
   // Loop over sampling points
-#pragma omp parallel
+#pragma omp parallel num_threads(nthreads)
 {
   Helper<T> hlp(gconf, grid, ogrid);
   T emb = exp(-2*beta);
@@ -1412,4 +1438,6 @@ PYBIND11_MODULE(nifty_gridder, m)
     "idx"_a, "grid"_a, "wgt"_a, "ms_in"_a=None);
   m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a,
     "grid"_a);
+  m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
+  m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
   }
-- 
GitLab