Commit 1816b95e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow controlling the number of threads

parent e2093551
......@@ -24,7 +24,6 @@
#include <iostream>
#include <algorithm>
#define POCKETFFT_OPENMP
#include "pocketfft_hdronly.h"
#ifdef __GNUC__
......@@ -63,6 +62,32 @@ template<typename T> inline T fmodulo (T v1, T v2)
return (tmp==v2) ? T(0) : tmp;
}
// Module-wide worker-thread count; consumed by the OpenMP regions and
// pocketfft calls elsewhere in this file. Defaults to single-threaded.
static size_t nthreads = 1;

// Python-visible docstring for set_nthreads (registered via pybind11).
constexpr auto set_nthreads_DS = R"""(
Specifies the number of threads to be used by the module
Parameters
==========
nthreads: int
the number of threads to be used. Must be >=1.
)""";

/// Sets the module-wide thread count.
/// Throws runtime_error when the requested count is zero (the only
/// value a size_t argument can take that is < 1).
void set_nthreads(size_t nthreads_)
  {
  if (nthreads_ == 0)
    throw runtime_error("nthreads must be >= 1");
  nthreads = nthreads_;
  }

// Python-visible docstring for get_nthreads (registered via pybind11).
constexpr auto get_nthreads_DS = R"""(
Returns the number of threads used by the module
Returns
=======
int : the number of threads used by the module
)""";

/// Returns the thread count currently in effect for the module.
size_t get_nthreads()
  {
  return nthreads;
  }
//
// Utilities for Gauss-Legendre quadrature
//
......@@ -81,7 +106,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w)
double t0 = 1 - (1-1./n) / (8.*n*n);
double t1 = 1./(4.*n+2.);
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
int i;
#pragma omp for schedule(dynamic,100)
......@@ -213,7 +238,7 @@ template<typename T> pyarr_c<T> complex2hartley
auto grid2 = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t u=0; u<nu; ++u)
{
size_t xu = (u==0) ? 0 : nu-u;
......@@ -241,7 +266,7 @@ template<typename T> pyarr_c<complex<T>> hartley2complex
auto grid2 = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t u=0; u<nu; ++u)
{
size_t xu = (u==0) ? 0 : nu-u;
......@@ -268,8 +293,9 @@ template<typename T> void hartley2_2D(const pyarr_c<T> &in, pyarr_c<T> &out)
auto ptmp = out.mutable_data();
{
py::gil_scoped_release release;
pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1), 0);
#pragma omp parallel for
pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1),
nthreads);
#pragma omp parallel for num_threads(nthreads)
for(size_t i=1; i<(nu+1)/2; ++i)
for(size_t j=1; j<(nv+1)/2; ++j)
{
......@@ -299,7 +325,7 @@ vector<double> correction_factors (size_t n, size_t nval, size_t w)
for (auto &v:psi)
v = exp(beta*(sqrt(1-v*v)-1.));
vector<double> res(nval);
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthreads)
for (size_t k=0; k<nval; ++k)
{
double tmp=0;
......@@ -397,7 +423,7 @@ template<typename T> class Baselines
auto vis = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
auto t = idx(i);
......@@ -439,7 +465,7 @@ template<typename T> class Baselines
auto ms = res.template mutable_unchecked<2>();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
auto t = idx(i);
......@@ -578,7 +604,7 @@ template<typename T> class GridderConfig
auto ptmp = tmp.mutable_data();
pocketfft::c2c({nu,nv},{grid.strides(0),grid.strides(1)},
{tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::BACKWARD,
grid.data(), tmp.mutable_data(), T(1), 0);
grid.data(), tmp.mutable_data(), T(1), nthreads);
auto res = makeArray<complex<T>>({nx_dirty, ny_dirty});
auto pout = res.mutable_data();
{
......@@ -640,7 +666,7 @@ template<typename T> class GridderConfig
ptmp[nv*i2+j2] = pdirty[ny_dirty*i + j]*cfu[i]*cfv[j];
}
pocketfft::c2c({nu,nv}, strides, strides, {0,1}, pocketfft::FORWARD,
ptmp, ptmp, T(1), 0);
ptmp, ptmp, T(1), nthreads);
}
return tmp;
}
......@@ -795,7 +821,7 @@ template<typename T> pyarr_c<complex<T>> vis2grid_c(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
......@@ -894,7 +920,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
......@@ -954,7 +980,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c_wgt(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
......@@ -1011,7 +1037,7 @@ template<typename T> pyarr_c<complex<T>> grid2vis_c(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
......@@ -1087,7 +1113,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c(const Baselines<T> &baselines
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
......@@ -1149,7 +1175,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c_wgt(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
......@@ -1210,7 +1236,7 @@ template<typename T> pyarr_c<complex<T>> apply_holo(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, ogrid);
T emb = exp(-2*beta);
......@@ -1412,4 +1438,6 @@ PYBIND11_MODULE(nifty_gridder, m)
"idx"_a, "grid"_a, "wgt"_a, "ms_in"_a=None);
m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a);
m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment