diff --git a/gridder_cxx.h b/gridder_cxx.h index b1bc3362c52f74b9aa427c9cc010b9b46adc1d6e..1c9646c56f2463a034c1497b876888ad3d19276d 100644 --- a/gridder_cxx.h +++ b/gridder_cxx.h @@ -149,17 +149,6 @@ template<typename T> inline T fmodulo (T v1, T v2) return (tmp==v2) ? T(0) : tmp; } -static size_t nthreads = 1; - -void set_nthreads(size_t nthreads_) - { - myassert(nthreads_>=1, "nthreads must be >= 1"); - nthreads = nthreads_; - } - -size_t get_nthreads() - { return nthreads; } - // // Utilities for Gauss-Legendre quadrature // @@ -167,7 +156,7 @@ size_t get_nthreads() static inline double one_minus_x2 (double x) { return (fabs(x)>0.1) ? (1.+x)*(1.-x) : 1.-x*x; } -void legendre_prep(int n, vector<double> &x, vector<double> &w) +void legendre_prep(int n, vector<double> &x, vector<double> &w, size_t nthreads) { constexpr double pi = 3.141592653589793238462643383279502884197; constexpr double eps = 3e-14; @@ -267,10 +256,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> public: using EC_Kernel<T>::operator(); - EC_Kernel_with_correction(size_t supp_) + EC_Kernel_with_correction(size_t supp_, size_t nthreads) : EC_Kernel<T>(supp_), p(int(1.5*supp_+2)) { - legendre_prep(2*p,x,wgt); + legendre_prep(2*p,x,wgt,nthreads); psi=x; for (auto &v:psi) v=operator()(v); @@ -287,9 +276,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> }; /* Compute correction factors for the ES gridding kernel This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ -vector<double> correction_factors (size_t n, size_t nval, size_t supp) +vector<double> correction_factors (size_t n, size_t nval, size_t supp, + size_t nthreads) { - EC_Kernel_with_correction<double> kernel(supp); + EC_Kernel_with_correction<double> kernel(supp, nthreads); vector<double> res(nval); double xn = 1./n; #pragma omp parallel for schedule(static) num_threads(nthreads) @@ -359,7 +349,7 @@ template<typename T> class Baselines } template<typename T2> void ms2vis(const mav<const T2,2> &ms, - const mav<const uint32_t,1> &idx, mav<T2,1> &vis) const + const mav<const uint32_t,1> &idx, mav<T2,1> &vis, size_t nthreads) const { myassert(ms.shape(0)==nrows, "shape mismatch"); myassert(ms.shape(1)==nchan, "shape mismatch"); @@ -374,7 +364,7 @@ template<typename T> class Baselines } template<typename T2> void vis2ms(const mav<const T2,1> &vis, - const mav<const uint32_t,1> &idx, mav<T2,2> &ms) const + const mav<const uint32_t,1> &idx, mav<T2,2> &ms, size_t nthreads) const { size_t nvis = vis.shape(0); myassert(idx.shape(0)==nvis, "shape mismatch"); @@ -397,6 +387,7 @@ template<typename T> class GridderConfig size_t supp, nsafe, nu, nv; T beta; vector<T> cfu, cfv; + size_t nthreads; complex<T> wscreen(double x, double y, double w, bool adjoint) const { @@ -412,13 +403,13 @@ template<typename T> class GridderConfig public: GridderConfig(size_t nxdirty, size_t nydirty, double epsilon, - double pixsize_x, double pixsize_y) + double pixsize_x, double pixsize_y, size_t nthreads_) : nx_dirty(nxdirty), ny_dirty(nydirty), eps(epsilon), psx(pixsize_x), psy(pixsize_y), supp(get_supp(epsilon)), nsafe((supp+1)/2), nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)), beta(2.3*supp), - cfu(nx_dirty), cfv(ny_dirty) + cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_) { myassert((nx_dirty&1)==0, "nx_dirty must be even"); myassert((ny_dirty&1)==0, "ny_dirty must be even"); @@ -426,12 +417,12 @@ template<typename T> class GridderConfig myassert(pixsize_x>0, "pixsize_x must be positive"); myassert(pixsize_y>0, "pixsize_y must be positive"); - auto tmp = correction_factors(nu, nx_dirty/2+1, supp); + auto tmp = correction_factors(nu, nx_dirty/2+1, supp, nthreads); cfu[nx_dirty/2]=tmp[0]; cfu[0]=tmp[nx_dirty/2]; for (size_t i=1; i<nx_dirty/2; ++i) cfu[nx_dirty/2-i] = cfu[nx_dirty/2+i] = tmp[i]; - tmp = correction_factors(nv, ny_dirty/2+1, supp); + tmp = correction_factors(nv, ny_dirty/2+1, supp, nthreads); cfv[ny_dirty/2]=tmp[0]; cfv[0]=tmp[ny_dirty/2]; for (size_t i=1; i<ny_dirty/2; ++i) @@ -447,6 +438,7 @@ template<typename T> class GridderConfig size_t Supp() const { return supp; } size_t Nsafe() const { return nsafe; } T Beta() const { return beta; } + size_t Nthreads() const { return nthreads; } void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const { @@ -733,6 +725,7 @@ template<typename T, typename Serv> void x2grid_c myassert(grid.contiguous(), "grid is not contiguous"); T beta = gconf.Beta(); size_t supp = gconf.Supp(); + size_t nthreads = gconf.Nthreads(); #pragma omp parallel num_threads(nthreads) { @@ -786,6 +779,7 @@ template<typename T, typename Serv> void grid2x_c myassert(grid.contiguous(), "grid is not contiguous"); T beta = gconf.Beta(); size_t supp = gconf.Supp(); + size_t nthreads = gconf.Nthreads(); // Loop over sampling points #pragma omp parallel num_threads(nthreads) @@ -840,6 +834,7 @@ template<typename T> void apply_holo size_t nu=gconf.Nu(), nv=gconf.Nv(); checkShape(grid.shape(), {nu, nv}); size_t nvis = idx.shape(0); + size_t nthreads = gconf.Nthreads(); bool have_wgt = wgt.size()!=0; if (have_wgt) checkShape(wgt.shape(), {nvis}); checkShape(ogrid.shape(), {nu, nv}); @@ -906,6 +901,7 @@ template<typename T> void get_correlations size_t supp = gconf.Supp(); myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp"); myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp"); + size_t nthreads = gconf.Nthreads(); T beta = gconf.Beta(); for (size_t i=0; i<nu; ++i) @@ -960,6 +956,7 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf { auto nx_dirty=gconf.Nxdirty(); auto ny_dirty=gconf.Nydirty(); + size_t nthreads = gconf.Nthreads(); auto psx=gconf.Pixsize_x(); auto psy=gconf.Pixsize_y(); double x0 = -0.5*nx_dirty*psx, @@ -1003,6 +1000,7 @@ template<typename T, typename Serv> void wstack_common( auto psx=gconf.Pixsize_x(); auto psy=gconf.Pixsize_y(); size_t nvis = srv.Nvis(); + size_t nthreads = gconf.Nthreads(); // determine w values for every visibility, and min/max w; wmin=T(1e38); @@ -1058,7 +1056,8 @@ template<typename T, typename Serv> void x2dirty_wstack( auto nv=gconf.Nv(); size_t nvis = srv.Nvis(); auto w_supp = gconf.Supp(); - EC_Kernel_with_correction<T> kernel(w_supp); + size_t nthreads = gconf.Nthreads(); + EC_Kernel_with_correction<T> kernel(w_supp, nthreads); T wmin; vector<T> wval; double dw; @@ -1134,7 +1133,8 @@ template<typename T, typename Serv> void dirty2x_wstack( auto nv=gconf.Nv(); size_t nvis = srv.Nvis(); auto w_supp = gconf.Supp(); - EC_Kernel_with_correction<T> kernel(w_supp); + size_t nthreads = gconf.Nthreads(); + EC_Kernel_with_correction<T> kernel(w_supp, nthreads); T wmin; vector<T> wval; double dw; @@ -1204,7 +1204,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines, template<typename T> size_t getIdxSize(const Baselines<T> &baselines, - const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax) + const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax, size_t nthreads) { size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels(); if (chbegin<0) chbegin=0; @@ -1318,9 +1318,8 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw, const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon, size_t nthreads, const mav<complex<T>,2> &dirty, size_t verbosity) { - set_nthreads(nthreads); Baselines<T> baselines(uvw, freq); - GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y); + GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads); auto idx = getWgtIndices(baselines, gconf, wgt); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt); @@ -1332,9 +1331,8 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon, size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity) { - set_nthreads(nthreads); Baselines<T> baselines(uvw, freq); - GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y); + GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads); auto idx = getWgtIndices(baselines, gconf, wgt); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); for (size_t i=0; i<ms.shape(0); ++i) @@ -1349,8 +1347,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, // public names using detail::mav; using detail::const_mav; -using detail::set_nthreads; -using detail::get_nthreads; using detail::myassert; using detail::Baselines; using detail::GridderConfig; diff --git a/nifty_gridder.cc b/nifty_gridder.cc index 24bfc9439f36467eae177aa431a9c3f5cd20cbd3..238800a9225d9114ecf7db796c8c554eb754d374 100644 --- a/nifty_gridder.cc +++ b/nifty_gridder.cc @@ -33,29 +33,6 @@ namespace { auto None = py::none(); -// -// basic utilities -// - -// -// Start of real gridder functionality -// -constexpr auto set_nthreads_DS = R"""( -Specifies the number of threads to be used by the module - -Parameters -========== -nthreads: int - the number of threads to be used. Must be >=1. -)"""; -constexpr auto get_nthreads_DS = R"""( -Returns the number of threads used by the module - -Returns -======= -int : the number of threads used by the module -)"""; - template<typename T> using pyarr = py::array_t<T, 0>; @@ -210,7 +187,7 @@ template<typename T> class PyBaselines: public Baselines<T> } template<typename T2> pyarr<T2> ms2vis(const pyarr<T2> &ms_, - const pyarr<uint32_t> &idx_) const + const pyarr<uint32_t> &idx_, size_t nthreads) const { auto idx=make_const_mav<1>(idx_); size_t nvis = size_t(idx.shape(0)); @@ -219,7 +196,7 @@ template<typename T> class PyBaselines: public Baselines<T> auto vis = make_mav<1>(res); { py::gil_scoped_release release; - Baselines<T>::ms2vis(ms, idx, vis); + Baselines<T>::ms2vis(ms, idx, vis, nthreads); } return res; } @@ -242,7 +219,7 @@ template<typename T> class PyBaselines: public Baselines<T> the measurement set's visibility data (0 where not covered by idx) )"""; template<typename T2> pyarr<T2> vis2ms(const pyarr<T2> &vis_, - const pyarr<uint32_t> &idx_, py::object &ms_in) const + const pyarr<uint32_t> &idx_, py::object &ms_in, size_t nthreads) const { auto vis=make_const_mav<1>(vis_); auto idx=make_const_mav<1>(idx_); @@ -250,7 +227,7 @@ template<typename T> class PyBaselines: public Baselines<T> auto ms = make_mav<2>(res); { py::gil_scoped_release release; - Baselines<T>::vis2ms(vis, idx, ms); + Baselines<T>::vis2ms(vis, idx, ms, nthreads); } return res; } @@ -318,8 +295,8 @@ template<typename T> class PyGridderConfig: public GridderConfig<T> public: using GridderConfig<T>::GridderConfig; PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon, - double pixsize_x, double pixsize_y) - : GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y) {} + double pixsize_x, double pixsize_y, size_t nthreads) + : GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads) {} using GridderConfig<T>::Nxdirty; using GridderConfig<T>::Nydirty; using GridderConfig<T>::Epsilon; @@ -616,13 +593,13 @@ np.array((nvis,), dtype=np.uint32) )"""; template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baselines, const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin, - int chend, T wmin, T wmax) + int chend, T wmin, T wmax, size_t nthreads) { size_t nidx; auto flags = make_const_mav<2>(flags_); { py::gil_scoped_release release; - nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax); + nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax, nthreads); } auto res = makeArray<uint32_t>({nidx}); auto res2 = make_mav<1>(res); @@ -709,13 +686,13 @@ PYBIND11_MODULE(nifty_gridder, m) .def ("Nrows",&PyBaselines<double>::Nrows) .def ("Nchannels",&PyBaselines<double>::Nchannels) .def ("ms2vis",&PyBaselines<double>::ms2vis<complex<double>>, - PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a) + PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a, "nthreads"_a=1) .def ("effectiveuvw",&PyBaselines<double>::effectiveuvw, "idx"_a) .def ("vis2ms",&PyBaselines<double>::vis2ms<complex<double>>, - PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None); + PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None, "nthreads"_a=1); py::class_<PyGridderConfig<double>> (m, "GridderConfig", GridderConfig_DS) - .def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a, - "nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a) + .def(py::init<size_t, size_t, double, double, double, size_t>(),"nxdirty"_a, + "nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1) .def("Nxdirty", &PyGridderConfig<double>::Nxdirty) .def("Nydirty", &PyGridderConfig<double>::Nydirty) .def("Epsilon", &PyGridderConfig<double>::Epsilon) @@ -737,21 +714,21 @@ PYBIND11_MODULE(nifty_gridder, m) [](const PyGridderConfig<double> & gc) { // Encode object state in tuple return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(), - gc.Pixsize_x(), gc.Pixsize_y()); + gc.Pixsize_x(), gc.Pixsize_y(), gc.Nthreads()); }, // __setstate__ [](py::tuple t) { - myassert(t.size()==5,"Invalid state"); + myassert(t.size()==6,"Invalid state"); // Reconstruct from tuple return PyGridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(), t[2].cast<double>(), t[3].cast<double>(), - t[4].cast<double>()); + t[4].cast<double>(), t[5].cast<size_t>()); })); m.def("getIndices", PygetIndices<double>, getIndices_DS, "baselines"_a, "gconf"_a, "flags"_a, "chbegin"_a=-1, "chend"_a=-1, - "wmin"_a=-1e30, "wmax"_a=1e30); + "wmin"_a=-1e30, "wmax"_a=1e30, "nthreads"_a=1); m.def("vis2grid_c",&Pyvis2grid_c<double>, vis2grid_c_DS, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None, "wgt"_a=None); m.def("ms2grid_c",&Pyms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a, @@ -764,8 +741,6 @@ PYBIND11_MODULE(nifty_gridder, m) "grid"_a, "wgt"_a=None); m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a, "idx"_a, "du"_a, "dv"_a, "wgt"_a=None); - m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a); - m.def("get_nthreads",&get_nthreads, get_nthreads_DS); m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a); m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,