Commit 0836cea1 authored by Martin Reinecke
Browse files

nthreads is no longer global

parent 46487e0f
......@@ -149,17 +149,6 @@ template<typename T> inline T fmodulo (T v1, T v2)
return (tmp==v2) ? T(0) : tmp;
}
// File-scope thread count: starts at 1, written by set_nthreads() and read by get_nthreads().
static size_t nthreads = 1;
// Stores nthreads_ in the file-scope `nthreads` variable.
// The value is first checked with myassert; it must be at least 1.
void set_nthreads(size_t nthreads_)
  {
  const bool count_ok = (nthreads_>=1);
  myassert(count_ok, "nthreads must be >= 1");
  nthreads = nthreads_;
  }
// Reports the value currently held in the file-scope `nthreads` variable.
size_t get_nthreads()
  {
  const size_t current = nthreads;
  return current;
  }
//
// Utilities for Gauss-Legendre quadrature
//
......@@ -167,7 +156,7 @@ size_t get_nthreads()
// Evaluates 1-x^2. The factored product (1+x)*(1-x) is used once |x| exceeds
// 0.1, the direct form 1-x*x otherwise (presumably to limit rounding loss
// when |x| is close to 1 - TODO confirm).
static inline double one_minus_x2 (double x)
  {
  if (fabs(x)>0.1)
    return (1.+x)*(1.-x);
  return 1.-x*x;
  }
void legendre_prep(int n, vector<double> &x, vector<double> &w)
void legendre_prep(int n, vector<double> &x, vector<double> &w, size_t nthreads)
{
constexpr double pi = 3.141592653589793238462643383279502884197;
constexpr double eps = 3e-14;
......@@ -267,10 +256,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
public:
using EC_Kernel<T>::operator();
EC_Kernel_with_correction(size_t supp_)
EC_Kernel_with_correction(size_t supp_, size_t nthreads)
: EC_Kernel<T>(supp_), p(int(1.5*supp_+2))
{
legendre_prep(2*p,x,wgt);
legendre_prep(2*p,x,wgt,nthreads);
psi=x;
for (auto &v:psi)
v=operator()(v);
......@@ -287,9 +276,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
};
/* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
vector<double> correction_factors (size_t n, size_t nval, size_t supp)
vector<double> correction_factors (size_t n, size_t nval, size_t supp,
size_t nthreads)
{
EC_Kernel_with_correction<double> kernel(supp);
EC_Kernel_with_correction<double> kernel(supp, nthreads);
vector<double> res(nval);
double xn = 1./n;
#pragma omp parallel for schedule(static) num_threads(nthreads)
......@@ -359,7 +349,7 @@ template<typename T> class Baselines
}
template<typename T2> void ms2vis(const mav<const T2,2> &ms,
const mav<const uint32_t,1> &idx, mav<T2,1> &vis) const
const mav<const uint32_t,1> &idx, mav<T2,1> &vis, size_t nthreads) const
{
myassert(ms.shape(0)==nrows, "shape mismatch");
myassert(ms.shape(1)==nchan, "shape mismatch");
......@@ -374,7 +364,7 @@ template<typename T> class Baselines
}
template<typename T2> void vis2ms(const mav<const T2,1> &vis,
const mav<const uint32_t,1> &idx, mav<T2,2> &ms) const
const mav<const uint32_t,1> &idx, mav<T2,2> &ms, size_t nthreads) const
{
size_t nvis = vis.shape(0);
myassert(idx.shape(0)==nvis, "shape mismatch");
......@@ -397,6 +387,7 @@ template<typename T> class GridderConfig
size_t supp, nsafe, nu, nv;
T beta;
vector<T> cfu, cfv;
size_t nthreads;
complex<T> wscreen(double x, double y, double w, bool adjoint) const
{
......@@ -412,13 +403,13 @@ template<typename T> class GridderConfig
public:
GridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
double pixsize_x, double pixsize_y)
double pixsize_x, double pixsize_y, size_t nthreads_)
: nx_dirty(nxdirty), ny_dirty(nydirty), eps(epsilon),
psx(pixsize_x), psy(pixsize_y),
supp(get_supp(epsilon)), nsafe((supp+1)/2),
nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)),
beta(2.3*supp),
cfu(nx_dirty), cfv(ny_dirty)
cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_)
{
myassert((nx_dirty&1)==0, "nx_dirty must be even");
myassert((ny_dirty&1)==0, "ny_dirty must be even");
......@@ -426,12 +417,12 @@ template<typename T> class GridderConfig
myassert(pixsize_x>0, "pixsize_x must be positive");
myassert(pixsize_y>0, "pixsize_y must be positive");
auto tmp = correction_factors(nu, nx_dirty/2+1, supp);
auto tmp = correction_factors(nu, nx_dirty/2+1, supp, nthreads);
cfu[nx_dirty/2]=tmp[0];
cfu[0]=tmp[nx_dirty/2];
for (size_t i=1; i<nx_dirty/2; ++i)
cfu[nx_dirty/2-i] = cfu[nx_dirty/2+i] = tmp[i];
tmp = correction_factors(nv, ny_dirty/2+1, supp);
tmp = correction_factors(nv, ny_dirty/2+1, supp, nthreads);
cfv[ny_dirty/2]=tmp[0];
cfv[0]=tmp[ny_dirty/2];
for (size_t i=1; i<ny_dirty/2; ++i)
......@@ -447,6 +438,7 @@ template<typename T> class GridderConfig
// Read-only access to the `supp` member (initialised from get_supp(epsilon) in the constructor).
size_t Supp() const
  {
  return supp;
  }
// Read-only access to the `nsafe` member (initialised to (supp+1)/2 in the constructor).
size_t Nsafe() const
  {
  return nsafe;
  }
// Read-only access to the `beta` member (initialised to 2.3*supp in the constructor).
T Beta() const
  {
  return beta;
  }
// Read-only access to the `nthreads` member (the thread count handed to the constructor).
size_t Nthreads() const
  {
  return nthreads;
  }
void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const
{
......@@ -733,6 +725,7 @@ template<typename T, typename Serv> void x2grid_c
myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta();
size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
#pragma omp parallel num_threads(nthreads)
{
......@@ -786,6 +779,7 @@ template<typename T, typename Serv> void grid2x_c
myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta();
size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
// Loop over sampling points
#pragma omp parallel num_threads(nthreads)
......@@ -840,6 +834,7 @@ template<typename T> void apply_holo
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
size_t nvis = idx.shape(0);
size_t nthreads = gconf.Nthreads();
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
checkShape(ogrid.shape(), {nu, nv});
......@@ -906,6 +901,7 @@ template<typename T> void get_correlations
size_t supp = gconf.Supp();
myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp");
myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp");
size_t nthreads = gconf.Nthreads();
T beta = gconf.Beta();
for (size_t i=0; i<nu; ++i)
......@@ -960,6 +956,7 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nthreads = gconf.Nthreads();
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
double x0 = -0.5*nx_dirty*psx,
......@@ -1003,6 +1000,7 @@ template<typename T, typename Serv> void wstack_common(
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
// determine w values for every visibility, and min/max w;
wmin=T(1e38);
......@@ -1058,7 +1056,8 @@ template<typename T, typename Serv> void x2dirty_wstack(
auto nv=gconf.Nv();
size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp();
EC_Kernel_with_correction<T> kernel(w_supp);
size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
T wmin;
vector<T> wval;
double dw;
......@@ -1134,7 +1133,8 @@ template<typename T, typename Serv> void dirty2x_wstack(
auto nv=gconf.Nv();
size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp();
EC_Kernel_with_correction<T> kernel(w_supp);
size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
T wmin;
vector<T> wval;
double dw;
......@@ -1204,7 +1204,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
template<typename T> size_t getIdxSize(const Baselines<T> &baselines,
const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax)
const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax, size_t nthreads)
{
size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels();
if (chbegin<0) chbegin=0;
......@@ -1318,9 +1318,8 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &dirty, size_t verbosity)
{
set_nthreads(nthreads);
Baselines<T> baselines(uvw, freq);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
......@@ -1332,9 +1331,8 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
{
set_nthreads(nthreads);
Baselines<T> baselines(uvw, freq);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
for (size_t i=0; i<ms.shape(0); ++i)
......@@ -1349,8 +1347,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
// public names
using detail::mav;
using detail::const_mav;
using detail::set_nthreads;
using detail::get_nthreads;
using detail::myassert;
using detail::Baselines;
using detail::GridderConfig;
......
......@@ -33,29 +33,6 @@ namespace {
// Python `None` object; used as the default value of optional keyword arguments
// in the bindings (e.g. "ms_in"_a=None).
auto None = py::none();
//
// basic utilities
//
//
// Start of real gridder functionality
//
// Docstring passed to m.def() for the Python-level set_nthreads() function.
constexpr const char *set_nthreads_DS = R"""(
Specifies the number of threads to be used by the module
Parameters
==========
nthreads: int
the number of threads to be used. Must be >=1.
)""";
// Docstring passed to m.def() for the Python-level get_nthreads() function.
constexpr const char *get_nthreads_DS = R"""(
Returns the number of threads used by the module
Returns
=======
int : the number of threads used by the module
)""";
// Shorthand for py::array_t<T, 0>; this is the array type taken and returned by
// the Python-facing wrapper methods in this file (e.g. PyBaselines::ms2vis).
template<typename T>
using pyarr = py::array_t<T, 0>;
......@@ -210,7 +187,7 @@ template<typename T> class PyBaselines: public Baselines<T>
}
template<typename T2> pyarr<T2> ms2vis(const pyarr<T2> &ms_,
const pyarr<uint32_t> &idx_) const
const pyarr<uint32_t> &idx_, size_t nthreads) const
{
auto idx=make_const_mav<1>(idx_);
size_t nvis = size_t(idx.shape(0));
......@@ -219,7 +196,7 @@ template<typename T> class PyBaselines: public Baselines<T>
auto vis = make_mav<1>(res);
{
py::gil_scoped_release release;
Baselines<T>::ms2vis(ms, idx, vis);
Baselines<T>::ms2vis(ms, idx, vis, nthreads);
}
return res;
}
......@@ -242,7 +219,7 @@ template<typename T> class PyBaselines: public Baselines<T>
the measurement set's visibility data (0 where not covered by idx)
)""";
template<typename T2> pyarr<T2> vis2ms(const pyarr<T2> &vis_,
const pyarr<uint32_t> &idx_, py::object &ms_in) const
const pyarr<uint32_t> &idx_, py::object &ms_in, size_t nthreads) const
{
auto vis=make_const_mav<1>(vis_);
auto idx=make_const_mav<1>(idx_);
......@@ -250,7 +227,7 @@ template<typename T> class PyBaselines: public Baselines<T>
auto ms = make_mav<2>(res);
{
py::gil_scoped_release release;
Baselines<T>::vis2ms(vis, idx, ms);
Baselines<T>::vis2ms(vis, idx, ms, nthreads);
}
return res;
}
......@@ -318,8 +295,8 @@ template<typename T> class PyGridderConfig: public GridderConfig<T>
public:
using GridderConfig<T>::GridderConfig;
PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
double pixsize_x, double pixsize_y)
: GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y) {}
double pixsize_x, double pixsize_y, size_t nthreads)
: GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads) {}
using GridderConfig<T>::Nxdirty;
using GridderConfig<T>::Nydirty;
using GridderConfig<T>::Epsilon;
......@@ -616,13 +593,13 @@ np.array((nvis,), dtype=np.uint32)
)""";
template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baselines,
const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin,
int chend, T wmin, T wmax)
int chend, T wmin, T wmax, size_t nthreads)
{
size_t nidx;
auto flags = make_const_mav<2>(flags_);
{
py::gil_scoped_release release;
nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax);
nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax, nthreads);
}
auto res = makeArray<uint32_t>({nidx});
auto res2 = make_mav<1>(res);
......@@ -709,13 +686,13 @@ PYBIND11_MODULE(nifty_gridder, m)
.def ("Nrows",&PyBaselines<double>::Nrows)
.def ("Nchannels",&PyBaselines<double>::Nchannels)
.def ("ms2vis",&PyBaselines<double>::ms2vis<complex<double>>,
PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a)
PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a, "nthreads"_a=1)
.def ("effectiveuvw",&PyBaselines<double>::effectiveuvw, "idx"_a)
.def ("vis2ms",&PyBaselines<double>::vis2ms<complex<double>>,
PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None);
PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None, "nthreads"_a=1);
py::class_<PyGridderConfig<double>> (m, "GridderConfig", GridderConfig_DS)
.def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a)
.def(py::init<size_t, size_t, double, double, double, size_t>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1)
.def("Nxdirty", &PyGridderConfig<double>::Nxdirty)
.def("Nydirty", &PyGridderConfig<double>::Nydirty)
.def("Epsilon", &PyGridderConfig<double>::Epsilon)
......@@ -737,21 +714,21 @@ PYBIND11_MODULE(nifty_gridder, m)
[](const PyGridderConfig<double> & gc) {
// Encode object state in tuple
return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(),
gc.Pixsize_x(), gc.Pixsize_y());
gc.Pixsize_x(), gc.Pixsize_y(), gc.Nthreads());
},
// __setstate__
[](py::tuple t) {
myassert(t.size()==5,"Invalid state");
myassert(t.size()==6,"Invalid state");
// Reconstruct from tuple
return PyGridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(),
t[2].cast<double>(), t[3].cast<double>(),
t[4].cast<double>());
t[4].cast<double>(), t[5].cast<size_t>());
}));
m.def("getIndices", PygetIndices<double>, getIndices_DS, "baselines"_a,
"gconf"_a, "flags"_a, "chbegin"_a=-1, "chend"_a=-1,
"wmin"_a=-1e30, "wmax"_a=1e30);
"wmin"_a=-1e30, "wmax"_a=1e30, "nthreads"_a=1);
m.def("vis2grid_c",&Pyvis2grid_c<double>, vis2grid_c_DS, "baselines"_a,
"gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None, "wgt"_a=None);
m.def("ms2grid_c",&Pyms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a,
......@@ -764,8 +741,6 @@ PYBIND11_MODULE(nifty_gridder, m)
"grid"_a, "wgt"_a=None);
m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a,
"idx"_a, "du"_a, "dv"_a, "wgt"_a=None);
m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a,
"idx"_a, "vis"_a);
m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment