From 1816b95ea0e7ead1f4b186f35dbd413fbb1114b3 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 14 Aug 2019 20:58:54 +0200 Subject: [PATCH] allow controlling the number of threads --- nifty_gridder.cc | 64 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/nifty_gridder.cc b/nifty_gridder.cc index e499fbd..3e0dfa8 100644 --- a/nifty_gridder.cc +++ b/nifty_gridder.cc @@ -24,7 +24,6 @@ #include <iostream> #include <algorithm> -#define POCKETFFT_OPENMP #include "pocketfft_hdronly.h" #ifdef __GNUC__ @@ -63,6 +62,32 @@ template<typename T> inline T fmodulo (T v1, T v2) return (tmp==v2) ? T(0) : tmp; } +static size_t nthreads = 1; + +constexpr auto set_nthreads_DS = R"""( +Specifies the number of threads to be used by the module + +Parameters +========== +nthreads: int + the number of threads to be used. Must be >=1. +)"""; +void set_nthreads(size_t nthreads_) + { + if (nthreads_<1) throw runtime_error("nthreads must be >= 1"); + nthreads = nthreads_; + } + +constexpr auto get_nthreads_DS = R"""( +Returns the number of threads used by the module + +Returns +======= +int : the number of threads used by the module +)"""; +size_t get_nthreads() + { return nthreads;} + // // Utilities for Gauss-Legendre quadrature // @@ -81,7 +106,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w) double t0 = 1 - (1-1./n) / (8.*n*n); double t1 = 1./(4.*n+2.); -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { int i; #pragma omp for schedule(dynamic,100) @@ -213,7 +238,7 @@ template<typename T> pyarr_c<T> complex2hartley auto grid2 = res.mutable_data(); { py::gil_scoped_release release; -#pragma omp parallel for +#pragma omp parallel for num_threads(nthreads) for (size_t u=0; u<nu; ++u) { size_t xu = (u==0) ? 0 : nu-u; @@ -241,7 +266,7 @@ template<typename T> pyarr_c<complex<T>> hartley2complex auto grid2 = res.mutable_data(); { py::gil_scoped_release release; -#pragma omp parallel for +#pragma omp parallel for num_threads(nthreads) for (size_t u=0; u<nu; ++u) { size_t xu = (u==0) ? 0 : nu-u; @@ -268,8 +293,9 @@ template<typename T> void hartley2_2D(const pyarr_c<T> &in, pyarr_c<T> &out) auto ptmp = out.mutable_data(); { py::gil_scoped_release release; - pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1), 0); -#pragma omp parallel for + pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1), + nthreads); +#pragma omp parallel for num_threads(nthreads) for(size_t i=1; i<(nu+1)/2; ++i) for(size_t j=1; j<(nv+1)/2; ++j) { @@ -299,7 +325,7 @@ vector<double> correction_factors (size_t n, size_t nval, size_t w) for (auto &v:psi) v = exp(beta*(sqrt(1-v*v)-1.)); vector<double> res(nval); -#pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) num_threads(nthreads) for (size_t k=0; k<nval; ++k) { double tmp=0; @@ -397,7 +423,7 @@ template<typename T> class Baselines auto vis = res.mutable_data(); { py::gil_scoped_release release; -#pragma omp parallel for +#pragma omp parallel for num_threads(nthreads) for (size_t i=0; i<nvis; ++i) { auto t = idx(i); @@ -439,7 +465,7 @@ template<typename T> class Baselines auto ms = res.template mutable_unchecked<2>(); { py::gil_scoped_release release; -#pragma omp parallel for +#pragma omp parallel for num_threads(nthreads) for (size_t i=0; i<nvis; ++i) { auto t = idx(i); @@ -578,7 +604,7 @@ template<typename T> class GridderConfig auto ptmp = tmp.mutable_data(); pocketfft::c2c({nu,nv},{grid.strides(0),grid.strides(1)}, {tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::BACKWARD, - grid.data(), tmp.mutable_data(), T(1), 0); + grid.data(), tmp.mutable_data(), T(1), nthreads); auto res = makeArray<complex<T>>({nx_dirty, ny_dirty}); auto pout = res.mutable_data(); { @@ -640,7 +666,7 @@ template<typename T> class GridderConfig ptmp[nv*i2+j2] = pdirty[ny_dirty*i + j]*cfu[i]*cfv[j]; } pocketfft::c2c({nu,nv}, strides, strides, {0,1}, pocketfft::FORWARD, - ptmp, ptmp, T(1), 0); + ptmp, ptmp, T(1), nthreads); } return tmp; } @@ -795,7 +821,7 @@ template<typename T> pyarr_c<complex<T>> vis2grid_c( T beta = gconf.Beta(); size_t w = gconf.W(); -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, nullptr, grid); T emb = exp(-2*beta); @@ -894,7 +920,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c( T beta = gconf.Beta(); size_t w = gconf.W(); -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, nullptr, grid); T emb = exp(-2*beta); @@ -954,7 +980,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c_wgt( T beta = gconf.Beta(); size_t w = gconf.W(); -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, nullptr, grid); T emb = exp(-2*beta); @@ -1011,7 +1037,7 @@ template<typename T> pyarr_c<complex<T>> grid2vis_c( size_t w = gconf.W(); // Loop over sampling points -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, grid, nullptr); T emb = exp(-2*beta); @@ -1087,7 +1113,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c(const Baselines<T> &baselines size_t w = gconf.W(); // Loop over sampling points -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, grid, nullptr); T emb = exp(-2*beta); @@ -1149,7 +1175,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c_wgt( size_t w = gconf.W(); // Loop over sampling points -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, grid, nullptr); T emb = exp(-2*beta); @@ -1210,7 +1236,7 @@ template<typename T> pyarr_c<complex<T>> apply_holo( size_t w = gconf.W(); // Loop over sampling points -#pragma omp parallel +#pragma omp parallel num_threads(nthreads) { Helper<T> hlp(gconf, grid, ogrid); T emb = exp(-2*beta); @@ -1412,4 +1438,6 @@ PYBIND11_MODULE(nifty_gridder, m) "idx"_a, "grid"_a, "wgt"_a, "ms_in"_a=None); m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a); + m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a); + m.def("get_nthreads",&get_nthreads, get_nthreads_DS); } -- GitLab