Commit 0836cea1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

nthreads is no longer global

parent 46487e0f
...@@ -149,17 +149,6 @@ template<typename T> inline T fmodulo (T v1, T v2) ...@@ -149,17 +149,6 @@ template<typename T> inline T fmodulo (T v1, T v2)
return (tmp==v2) ? T(0) : tmp; return (tmp==v2) ? T(0) : tmp;
} }
static size_t nthreads = 1;
void set_nthreads(size_t nthreads_)
{
myassert(nthreads_>=1, "nthreads must be >= 1");
nthreads = nthreads_;
}
size_t get_nthreads()
{ return nthreads; }
// //
// Utilities for Gauss-Legendre quadrature // Utilities for Gauss-Legendre quadrature
// //
...@@ -167,7 +156,7 @@ size_t get_nthreads() ...@@ -167,7 +156,7 @@ size_t get_nthreads()
static inline double one_minus_x2 (double x) static inline double one_minus_x2 (double x)
{ return (fabs(x)>0.1) ? (1.+x)*(1.-x) : 1.-x*x; } { return (fabs(x)>0.1) ? (1.+x)*(1.-x) : 1.-x*x; }
void legendre_prep(int n, vector<double> &x, vector<double> &w) void legendre_prep(int n, vector<double> &x, vector<double> &w, size_t nthreads)
{ {
constexpr double pi = 3.141592653589793238462643383279502884197; constexpr double pi = 3.141592653589793238462643383279502884197;
constexpr double eps = 3e-14; constexpr double eps = 3e-14;
...@@ -267,10 +256,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> ...@@ -267,10 +256,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
public: public:
using EC_Kernel<T>::operator(); using EC_Kernel<T>::operator();
EC_Kernel_with_correction(size_t supp_) EC_Kernel_with_correction(size_t supp_, size_t nthreads)
: EC_Kernel<T>(supp_), p(int(1.5*supp_+2)) : EC_Kernel<T>(supp_), p(int(1.5*supp_+2))
{ {
legendre_prep(2*p,x,wgt); legendre_prep(2*p,x,wgt,nthreads);
psi=x; psi=x;
for (auto &v:psi) for (auto &v:psi)
v=operator()(v); v=operator()(v);
...@@ -287,9 +276,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> ...@@ -287,9 +276,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
}; };
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
vector<double> correction_factors (size_t n, size_t nval, size_t supp) vector<double> correction_factors (size_t n, size_t nval, size_t supp,
size_t nthreads)
{ {
EC_Kernel_with_correction<double> kernel(supp); EC_Kernel_with_correction<double> kernel(supp, nthreads);
vector<double> res(nval); vector<double> res(nval);
double xn = 1./n; double xn = 1./n;
#pragma omp parallel for schedule(static) num_threads(nthreads) #pragma omp parallel for schedule(static) num_threads(nthreads)
...@@ -359,7 +349,7 @@ template<typename T> class Baselines ...@@ -359,7 +349,7 @@ template<typename T> class Baselines
} }
template<typename T2> void ms2vis(const mav<const T2,2> &ms, template<typename T2> void ms2vis(const mav<const T2,2> &ms,
const mav<const uint32_t,1> &idx, mav<T2,1> &vis) const const mav<const uint32_t,1> &idx, mav<T2,1> &vis, size_t nthreads) const
{ {
myassert(ms.shape(0)==nrows, "shape mismatch"); myassert(ms.shape(0)==nrows, "shape mismatch");
myassert(ms.shape(1)==nchan, "shape mismatch"); myassert(ms.shape(1)==nchan, "shape mismatch");
...@@ -374,7 +364,7 @@ template<typename T> class Baselines ...@@ -374,7 +364,7 @@ template<typename T> class Baselines
} }
template<typename T2> void vis2ms(const mav<const T2,1> &vis, template<typename T2> void vis2ms(const mav<const T2,1> &vis,
const mav<const uint32_t,1> &idx, mav<T2,2> &ms) const const mav<const uint32_t,1> &idx, mav<T2,2> &ms, size_t nthreads) const
{ {
size_t nvis = vis.shape(0); size_t nvis = vis.shape(0);
myassert(idx.shape(0)==nvis, "shape mismatch"); myassert(idx.shape(0)==nvis, "shape mismatch");
...@@ -397,6 +387,7 @@ template<typename T> class GridderConfig ...@@ -397,6 +387,7 @@ template<typename T> class GridderConfig
size_t supp, nsafe, nu, nv; size_t supp, nsafe, nu, nv;
T beta; T beta;
vector<T> cfu, cfv; vector<T> cfu, cfv;
size_t nthreads;
complex<T> wscreen(double x, double y, double w, bool adjoint) const complex<T> wscreen(double x, double y, double w, bool adjoint) const
{ {
...@@ -412,13 +403,13 @@ template<typename T> class GridderConfig ...@@ -412,13 +403,13 @@ template<typename T> class GridderConfig
public: public:
GridderConfig(size_t nxdirty, size_t nydirty, double epsilon, GridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
double pixsize_x, double pixsize_y) double pixsize_x, double pixsize_y, size_t nthreads_)
: nx_dirty(nxdirty), ny_dirty(nydirty), eps(epsilon), : nx_dirty(nxdirty), ny_dirty(nydirty), eps(epsilon),
psx(pixsize_x), psy(pixsize_y), psx(pixsize_x), psy(pixsize_y),
supp(get_supp(epsilon)), nsafe((supp+1)/2), supp(get_supp(epsilon)), nsafe((supp+1)/2),
nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)), nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)),
beta(2.3*supp), beta(2.3*supp),
cfu(nx_dirty), cfv(ny_dirty) cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_)
{ {
myassert((nx_dirty&1)==0, "nx_dirty must be even"); myassert((nx_dirty&1)==0, "nx_dirty must be even");
myassert((ny_dirty&1)==0, "ny_dirty must be even"); myassert((ny_dirty&1)==0, "ny_dirty must be even");
...@@ -426,12 +417,12 @@ template<typename T> class GridderConfig ...@@ -426,12 +417,12 @@ template<typename T> class GridderConfig
myassert(pixsize_x>0, "pixsize_x must be positive"); myassert(pixsize_x>0, "pixsize_x must be positive");
myassert(pixsize_y>0, "pixsize_y must be positive"); myassert(pixsize_y>0, "pixsize_y must be positive");
auto tmp = correction_factors(nu, nx_dirty/2+1, supp); auto tmp = correction_factors(nu, nx_dirty/2+1, supp, nthreads);
cfu[nx_dirty/2]=tmp[0]; cfu[nx_dirty/2]=tmp[0];
cfu[0]=tmp[nx_dirty/2]; cfu[0]=tmp[nx_dirty/2];
for (size_t i=1; i<nx_dirty/2; ++i) for (size_t i=1; i<nx_dirty/2; ++i)
cfu[nx_dirty/2-i] = cfu[nx_dirty/2+i] = tmp[i]; cfu[nx_dirty/2-i] = cfu[nx_dirty/2+i] = tmp[i];
tmp = correction_factors(nv, ny_dirty/2+1, supp); tmp = correction_factors(nv, ny_dirty/2+1, supp, nthreads);
cfv[ny_dirty/2]=tmp[0]; cfv[ny_dirty/2]=tmp[0];
cfv[0]=tmp[ny_dirty/2]; cfv[0]=tmp[ny_dirty/2];
for (size_t i=1; i<ny_dirty/2; ++i) for (size_t i=1; i<ny_dirty/2; ++i)
...@@ -447,6 +438,7 @@ template<typename T> class GridderConfig ...@@ -447,6 +438,7 @@ template<typename T> class GridderConfig
size_t Supp() const { return supp; } size_t Supp() const { return supp; }
size_t Nsafe() const { return nsafe; } size_t Nsafe() const { return nsafe; }
T Beta() const { return beta; } T Beta() const { return beta; }
size_t Nthreads() const { return nthreads; }
void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const
{ {
...@@ -733,6 +725,7 @@ template<typename T, typename Serv> void x2grid_c ...@@ -733,6 +725,7 @@ template<typename T, typename Serv> void x2grid_c
myassert(grid.contiguous(), "grid is not contiguous"); myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta(); T beta = gconf.Beta();
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
{ {
...@@ -786,6 +779,7 @@ template<typename T, typename Serv> void grid2x_c ...@@ -786,6 +779,7 @@ template<typename T, typename Serv> void grid2x_c
myassert(grid.contiguous(), "grid is not contiguous"); myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta(); T beta = gconf.Beta();
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
// Loop over sampling points // Loop over sampling points
#pragma omp parallel num_threads(nthreads) #pragma omp parallel num_threads(nthreads)
...@@ -840,6 +834,7 @@ template<typename T> void apply_holo ...@@ -840,6 +834,7 @@ template<typename T> void apply_holo
size_t nu=gconf.Nu(), nv=gconf.Nv(); size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv}); checkShape(grid.shape(), {nu, nv});
size_t nvis = idx.shape(0); size_t nvis = idx.shape(0);
size_t nthreads = gconf.Nthreads();
bool have_wgt = wgt.size()!=0; bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis}); if (have_wgt) checkShape(wgt.shape(), {nvis});
checkShape(ogrid.shape(), {nu, nv}); checkShape(ogrid.shape(), {nu, nv});
...@@ -906,6 +901,7 @@ template<typename T> void get_correlations ...@@ -906,6 +901,7 @@ template<typename T> void get_correlations
size_t supp = gconf.Supp(); size_t supp = gconf.Supp();
myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp"); myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp");
myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp"); myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp");
size_t nthreads = gconf.Nthreads();
T beta = gconf.Beta(); T beta = gconf.Beta();
for (size_t i=0; i<nu; ++i) for (size_t i=0; i<nu; ++i)
...@@ -960,6 +956,7 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf ...@@ -960,6 +956,7 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf
{ {
auto nx_dirty=gconf.Nxdirty(); auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty(); auto ny_dirty=gconf.Nydirty();
size_t nthreads = gconf.Nthreads();
auto psx=gconf.Pixsize_x(); auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y(); auto psy=gconf.Pixsize_y();
double x0 = -0.5*nx_dirty*psx, double x0 = -0.5*nx_dirty*psx,
...@@ -1003,6 +1000,7 @@ template<typename T, typename Serv> void wstack_common( ...@@ -1003,6 +1000,7 @@ template<typename T, typename Serv> void wstack_common(
auto psx=gconf.Pixsize_x(); auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y(); auto psy=gconf.Pixsize_y();
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
// determine w values for every visibility, and min/max w; // determine w values for every visibility, and min/max w;
wmin=T(1e38); wmin=T(1e38);
...@@ -1058,7 +1056,8 @@ template<typename T, typename Serv> void x2dirty_wstack( ...@@ -1058,7 +1056,8 @@ template<typename T, typename Serv> void x2dirty_wstack(
auto nv=gconf.Nv(); auto nv=gconf.Nv();
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp(); auto w_supp = gconf.Supp();
EC_Kernel_with_correction<T> kernel(w_supp); size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
T wmin; T wmin;
vector<T> wval; vector<T> wval;
double dw; double dw;
...@@ -1134,7 +1133,8 @@ template<typename T, typename Serv> void dirty2x_wstack( ...@@ -1134,7 +1133,8 @@ template<typename T, typename Serv> void dirty2x_wstack(
auto nv=gconf.Nv(); auto nv=gconf.Nv();
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp(); auto w_supp = gconf.Supp();
EC_Kernel_with_correction<T> kernel(w_supp); size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
T wmin; T wmin;
vector<T> wval; vector<T> wval;
double dw; double dw;
...@@ -1204,7 +1204,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines, ...@@ -1204,7 +1204,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
template<typename T> size_t getIdxSize(const Baselines<T> &baselines, template<typename T> size_t getIdxSize(const Baselines<T> &baselines,
const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax) const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax, size_t nthreads)
{ {
size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels(); size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels();
if (chbegin<0) chbegin=0; if (chbegin<0) chbegin=0;
...@@ -1318,9 +1318,8 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw, ...@@ -1318,9 +1318,8 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon, const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &dirty, size_t verbosity) size_t nthreads, const mav<complex<T>,2> &dirty, size_t verbosity)
{ {
set_nthreads(nthreads);
Baselines<T> baselines(uvw, freq); Baselines<T> baselines(uvw, freq);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y); GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt); auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt); auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
...@@ -1332,9 +1331,8 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, ...@@ -1332,9 +1331,8 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon, const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity) size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
{ {
set_nthreads(nthreads);
Baselines<T> baselines(uvw, freq); Baselines<T> baselines(uvw, freq);
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y); GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt); auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
for (size_t i=0; i<ms.shape(0); ++i) for (size_t i=0; i<ms.shape(0); ++i)
...@@ -1349,8 +1347,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, ...@@ -1349,8 +1347,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
// public names // public names
using detail::mav; using detail::mav;
using detail::const_mav; using detail::const_mav;
using detail::set_nthreads;
using detail::get_nthreads;
using detail::myassert; using detail::myassert;
using detail::Baselines; using detail::Baselines;
using detail::GridderConfig; using detail::GridderConfig;
......
...@@ -33,29 +33,6 @@ namespace { ...@@ -33,29 +33,6 @@ namespace {
auto None = py::none(); auto None = py::none();
//
// basic utilities
//
//
// Start of real gridder functionality
//
constexpr auto set_nthreads_DS = R"""(
Specifies the number of threads to be used by the module
Parameters
==========
nthreads: int
the number of threads to be used. Must be >=1.
)""";
constexpr auto get_nthreads_DS = R"""(
Returns the number of threads used by the module
Returns
=======
int : the number of threads used by the module
)""";
template<typename T> template<typename T>
using pyarr = py::array_t<T, 0>; using pyarr = py::array_t<T, 0>;
...@@ -210,7 +187,7 @@ template<typename T> class PyBaselines: public Baselines<T> ...@@ -210,7 +187,7 @@ template<typename T> class PyBaselines: public Baselines<T>
} }
template<typename T2> pyarr<T2> ms2vis(const pyarr<T2> &ms_, template<typename T2> pyarr<T2> ms2vis(const pyarr<T2> &ms_,
const pyarr<uint32_t> &idx_) const const pyarr<uint32_t> &idx_, size_t nthreads) const
{ {
auto idx=make_const_mav<1>(idx_); auto idx=make_const_mav<1>(idx_);
size_t nvis = size_t(idx.shape(0)); size_t nvis = size_t(idx.shape(0));
...@@ -219,7 +196,7 @@ template<typename T> class PyBaselines: public Baselines<T> ...@@ -219,7 +196,7 @@ template<typename T> class PyBaselines: public Baselines<T>
auto vis = make_mav<1>(res); auto vis = make_mav<1>(res);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
Baselines<T>::ms2vis(ms, idx, vis); Baselines<T>::ms2vis(ms, idx, vis, nthreads);
} }
return res; return res;
} }
...@@ -242,7 +219,7 @@ template<typename T> class PyBaselines: public Baselines<T> ...@@ -242,7 +219,7 @@ template<typename T> class PyBaselines: public Baselines<T>
the measurement set's visibility data (0 where not covered by idx) the measurement set's visibility data (0 where not covered by idx)
)"""; )""";
template<typename T2> pyarr<T2> vis2ms(const pyarr<T2> &vis_, template<typename T2> pyarr<T2> vis2ms(const pyarr<T2> &vis_,
const pyarr<uint32_t> &idx_, py::object &ms_in) const const pyarr<uint32_t> &idx_, py::object &ms_in, size_t nthreads) const
{ {
auto vis=make_const_mav<1>(vis_); auto vis=make_const_mav<1>(vis_);
auto idx=make_const_mav<1>(idx_); auto idx=make_const_mav<1>(idx_);
...@@ -250,7 +227,7 @@ template<typename T> class PyBaselines: public Baselines<T> ...@@ -250,7 +227,7 @@ template<typename T> class PyBaselines: public Baselines<T>
auto ms = make_mav<2>(res); auto ms = make_mav<2>(res);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
Baselines<T>::vis2ms(vis, idx, ms); Baselines<T>::vis2ms(vis, idx, ms, nthreads);
} }
return res; return res;
} }
...@@ -318,8 +295,8 @@ template<typename T> class PyGridderConfig: public GridderConfig<T> ...@@ -318,8 +295,8 @@ template<typename T> class PyGridderConfig: public GridderConfig<T>
public: public:
using GridderConfig<T>::GridderConfig; using GridderConfig<T>::GridderConfig;
PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon, PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
double pixsize_x, double pixsize_y) double pixsize_x, double pixsize_y, size_t nthreads)
: GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y) {} : GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads) {}
using GridderConfig<T>::Nxdirty; using GridderConfig<T>::Nxdirty;
using GridderConfig<T>::Nydirty; using GridderConfig<T>::Nydirty;
using GridderConfig<T>::Epsilon; using GridderConfig<T>::Epsilon;
...@@ -616,13 +593,13 @@ np.array((nvis,), dtype=np.uint32) ...@@ -616,13 +593,13 @@ np.array((nvis,), dtype=np.uint32)
)"""; )""";
template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baselines, template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baselines,
const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin, const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin,
int chend, T wmin, T wmax) int chend, T wmin, T wmax, size_t nthreads)
{ {
size_t nidx; size_t nidx;
auto flags = make_const_mav<2>(flags_); auto flags = make_const_mav<2>(flags_);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax); nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax, nthreads);
} }
auto res = makeArray<uint32_t>({nidx}); auto res = makeArray<uint32_t>({nidx});
auto res2 = make_mav<1>(res); auto res2 = make_mav<1>(res);
...@@ -709,13 +686,13 @@ PYBIND11_MODULE(nifty_gridder, m) ...@@ -709,13 +686,13 @@ PYBIND11_MODULE(nifty_gridder, m)
.def ("Nrows",&PyBaselines<double>::Nrows) .def ("Nrows",&PyBaselines<double>::Nrows)
.def ("Nchannels",&PyBaselines<double>::Nchannels) .def ("Nchannels",&PyBaselines<double>::Nchannels)
.def ("ms2vis",&PyBaselines<double>::ms2vis<complex<double>>, .def ("ms2vis",&PyBaselines<double>::ms2vis<complex<double>>,
PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a) PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a, "nthreads"_a=1)
.def ("effectiveuvw",&PyBaselines<double>::effectiveuvw, "idx"_a) .def ("effectiveuvw",&PyBaselines<double>::effectiveuvw, "idx"_a)
.def ("vis2ms",&PyBaselines<double>::vis2ms<complex<double>>, .def ("vis2ms",&PyBaselines<double>::vis2ms<complex<double>>,
PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None); PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None, "nthreads"_a=1);
py::class_<PyGridderConfig<double>> (m, "GridderConfig", GridderConfig_DS) py::class_<PyGridderConfig<double>> (m, "GridderConfig", GridderConfig_DS)
.def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a, .def(py::init<size_t, size_t, double, double, double, size_t>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a) "nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1)
.def("Nxdirty", &PyGridderConfig<double>::Nxdirty) .def("Nxdirty", &PyGridderConfig<double>::Nxdirty)
.def("Nydirty", &PyGridderConfig<double>::Nydirty) .def("Nydirty", &PyGridderConfig<double>::Nydirty)
.def("Epsilon", &PyGridderConfig<double>::Epsilon) .def("Epsilon", &PyGridderConfig<double>::Epsilon)
...@@ -737,21 +714,21 @@ PYBIND11_MODULE(nifty_gridder, m) ...@@ -737,21 +714,21 @@ PYBIND11_MODULE(nifty_gridder, m)
[](const PyGridderConfig<double> & gc) { [](const PyGridderConfig<double> & gc) {
// Encode object state in tuple // Encode object state in tuple
return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(), return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(),
gc.Pixsize_x(), gc.Pixsize_y()); gc.Pixsize_x(), gc.Pixsize_y(), gc.Nthreads());
}, },
// __setstate__ // __setstate__
[](py::tuple t) { [](py::tuple t) {
myassert(t.size()==5,"Invalid state"); myassert(t.size()==6,"Invalid state");
// Reconstruct from tuple // Reconstruct from tuple
return PyGridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(), return PyGridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(),
t[2].cast<double>(), t[3].cast<double>(), t[2].cast<double>(), t[3].cast<double>(),
t[4].cast<double>()); t[4].cast<double>(), t[5].cast<size_t>());
})); }));
m.def("getIndices", PygetIndices<double>, getIndices_DS, "baselines"_a, m.def("getIndices", PygetIndices<double>, getIndices_DS, "baselines"_a,
"gconf"_a, "flags"_a, "chbegin"_a=-1, "chend"_a=-1, "gconf"_a, "flags"_a, "chbegin"_a=-1, "chend"_a=-1,
"wmin"_a=-1e30, "wmax"_a=1e30); "wmin"_a=-1e30, "wmax"_a=1e30, "nthreads"_a=1);
m.def("vis2grid_c",&Pyvis2grid_c<double>, vis2grid_c_DS, "baselines"_a, m.def("vis2grid_c",&Pyvis2grid_c<double>, vis2grid_c_DS, "baselines"_a,
"gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None, "wgt"_a=None); "gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None, "wgt"_a=None);
m.def("ms2grid_c",&Pyms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a, m.def("ms2grid_c",&Pyms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a,
...@@ -764,8 +741,6 @@ PYBIND11_MODULE(nifty_gridder, m) ...@@ -764,8 +741,6 @@ PYBIND11_MODULE(nifty_gridder, m)
"grid"_a, "wgt"_a=None); "grid"_a, "wgt"_a=None);
m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a, m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a,
"idx"_a, "du"_a, "dv"_a, "wgt"_a=None); "idx"_a, "du"_a, "dv"_a, "wgt"_a=None);
m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a, m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a,
"idx"_a, "vis"_a); "idx"_a, "vis"_a);
m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a, m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment