diff --git a/gridder_cxx.h b/gridder_cxx.h
index b1bc3362c52f74b9aa427c9cc010b9b46adc1d6e..1c9646c56f2463a034c1497b876888ad3d19276d 100644
--- a/gridder_cxx.h
+++ b/gridder_cxx.h
@@ -149,17 +149,6 @@ template<typename T> inline T fmodulo (T v1, T v2)
   return (tmp==v2) ? T(0) : tmp;
   }
 
-static size_t nthreads = 1;
-
-void set_nthreads(size_t nthreads_)
-  {
-  myassert(nthreads_>=1, "nthreads must be >= 1");
-  nthreads = nthreads_;
-  }
-
-size_t get_nthreads()
-  { return nthreads; }
-
 //
 // Utilities for Gauss-Legendre quadrature
 //
@@ -167,7 +156,7 @@ size_t get_nthreads()
 static inline double one_minus_x2 (double x)
   { return (fabs(x)>0.1) ? (1.+x)*(1.-x) : 1.-x*x; }
 
-void legendre_prep(int n, vector<double> &x, vector<double> &w)
+void legendre_prep(int n, vector<double> &x, vector<double> &w, size_t nthreads)
   {
   constexpr double pi = 3.141592653589793238462643383279502884197;
   constexpr double eps = 3e-14;
@@ -267,10 +256,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
 
   public:
     using EC_Kernel<T>::operator();
-    EC_Kernel_with_correction(size_t supp_)
+    EC_Kernel_with_correction(size_t supp_, size_t nthreads)
       : EC_Kernel<T>(supp_), p(int(1.5*supp_+2))
       {
-      legendre_prep(2*p,x,wgt);
+      legendre_prep(2*p,x,wgt,nthreads);
       psi=x;
       for (auto &v:psi)
         v=operator()(v);
@@ -287,9 +276,10 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
   };
 /* Compute correction factors for the ES gridding kernel
    This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
-vector<double> correction_factors (size_t n, size_t nval, size_t supp)
+vector<double> correction_factors (size_t n, size_t nval, size_t supp,
+  size_t nthreads)
   {
-  EC_Kernel_with_correction<double> kernel(supp);
+  EC_Kernel_with_correction<double> kernel(supp, nthreads);
   vector<double> res(nval);
   double xn = 1./n;
 #pragma omp parallel for schedule(static) num_threads(nthreads)
@@ -359,7 +349,7 @@ template<typename T> class Baselines
       }
 
     template<typename T2> void ms2vis(const mav<const T2,2> &ms,
-      const mav<const uint32_t,1> &idx, mav<T2,1> &vis) const
+      const mav<const uint32_t,1> &idx, mav<T2,1> &vis, size_t nthreads) const
       {
       myassert(ms.shape(0)==nrows, "shape mismatch");
       myassert(ms.shape(1)==nchan, "shape mismatch");
@@ -374,7 +364,7 @@ template<typename T> class Baselines
       }
 
     template<typename T2> void vis2ms(const mav<const T2,1> &vis,
-      const mav<const uint32_t,1> &idx, mav<T2,2> &ms) const
+      const mav<const uint32_t,1> &idx, mav<T2,2> &ms, size_t nthreads) const
       {
       size_t nvis = vis.shape(0);
       myassert(idx.shape(0)==nvis, "shape mismatch");
@@ -397,6 +387,7 @@ template<typename T> class GridderConfig
     size_t supp, nsafe, nu, nv;
     T beta;
     vector<T> cfu, cfv;
+    size_t nthreads;
 
     complex<T> wscreen(double x, double y, double w, bool adjoint) const
       {
@@ -412,13 +403,13 @@ template<typename T> class GridderConfig
 
   public:
     GridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
-      double pixsize_x, double pixsize_y)
+      double pixsize_x, double pixsize_y, size_t nthreads_)
       : nx_dirty(nxdirty), ny_dirty(nydirty), eps(epsilon),
         psx(pixsize_x), psy(pixsize_y),
         supp(get_supp(epsilon)), nsafe((supp+1)/2),
         nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)),
         beta(2.3*supp),
-        cfu(nx_dirty), cfv(ny_dirty)
+        cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_)
       {
       myassert((nx_dirty&1)==0, "nx_dirty must be even");
       myassert((ny_dirty&1)==0, "ny_dirty must be even");
@@ -426,12 +417,12 @@ template<typename T> class GridderConfig
       myassert(pixsize_x>0, "pixsize_x must be positive");
       myassert(pixsize_y>0, "pixsize_y must be positive");
-
-      auto tmp = correction_factors(nu, nx_dirty/2+1, supp);
+      myassert(nthreads>=1, "nthreads must be >= 1"); // replaces the check from the removed set_nthreads(); num_threads() needs a positive value
+      auto tmp = correction_factors(nu, nx_dirty/2+1, supp, nthreads);
       cfu[nx_dirty/2]=tmp[0];
       cfu[0]=tmp[nx_dirty/2];
       for (size_t i=1; i<nx_dirty/2; ++i)
         cfu[nx_dirty/2-i] = cfu[nx_dirty/2+i] = tmp[i];
-      tmp = correction_factors(nv, ny_dirty/2+1, supp);
+      tmp = correction_factors(nv, ny_dirty/2+1, supp, nthreads);
       cfv[ny_dirty/2]=tmp[0];
       cfv[0]=tmp[ny_dirty/2];
       for (size_t i=1; i<ny_dirty/2; ++i)
@@ -447,6 +438,7 @@ template<typename T> class GridderConfig
     size_t Supp() const { return supp; }
     size_t Nsafe() const { return nsafe; }
     T Beta() const { return beta; }
+    size_t Nthreads() const { return nthreads; }
 
     void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const
       {
@@ -733,6 +725,7 @@ template<typename T, typename Serv> void x2grid_c
   myassert(grid.contiguous(), "grid is not contiguous");
   T beta = gconf.Beta();
   size_t supp = gconf.Supp();
+  size_t nthreads = gconf.Nthreads();
 
 #pragma omp parallel num_threads(nthreads)
 {
@@ -786,6 +779,7 @@ template<typename T, typename Serv> void grid2x_c
   myassert(grid.contiguous(), "grid is not contiguous");
   T beta = gconf.Beta();
   size_t supp = gconf.Supp();
+  size_t nthreads = gconf.Nthreads();
 
   // Loop over sampling points
 #pragma omp parallel num_threads(nthreads)
@@ -840,6 +834,7 @@ template<typename T> void apply_holo
   size_t nu=gconf.Nu(), nv=gconf.Nv();
   checkShape(grid.shape(), {nu, nv});
   size_t nvis = idx.shape(0);
+  size_t nthreads = gconf.Nthreads();
   bool have_wgt = wgt.size()!=0;
   if (have_wgt) checkShape(wgt.shape(), {nvis});
   checkShape(ogrid.shape(), {nu, nv});
@@ -906,6 +901,7 @@ template<typename T> void get_correlations
   size_t supp = gconf.Supp();
   myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp");
   myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp");
+  size_t nthreads = gconf.Nthreads();
 
   T beta = gconf.Beta();
   for (size_t i=0; i<nu; ++i)
@@ -960,6 +956,7 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf
   {
   auto nx_dirty=gconf.Nxdirty();
   auto ny_dirty=gconf.Nydirty();
+  size_t nthreads = gconf.Nthreads();
   auto psx=gconf.Pixsize_x();
   auto psy=gconf.Pixsize_y();
   double x0 = -0.5*nx_dirty*psx,
@@ -1003,6 +1000,7 @@ template<typename T, typename Serv> void wstack_common(
   auto psx=gconf.Pixsize_x();
   auto psy=gconf.Pixsize_y();
   size_t nvis = srv.Nvis();
+  size_t nthreads = gconf.Nthreads();
 
   // determine w values for every visibility, and min/max w;
   wmin=T(1e38);
@@ -1058,7 +1056,8 @@ template<typename T, typename Serv> void x2dirty_wstack(
   auto nv=gconf.Nv();
   size_t nvis = srv.Nvis();
   auto w_supp = gconf.Supp();
-  EC_Kernel_with_correction<T> kernel(w_supp);
+  size_t nthreads = gconf.Nthreads();
+  EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
   T wmin;
   vector<T> wval;
   double dw;
@@ -1134,7 +1133,8 @@ template<typename T, typename Serv> void dirty2x_wstack(
   auto nv=gconf.Nv();
   size_t nvis = srv.Nvis();
   auto w_supp = gconf.Supp();
-  EC_Kernel_with_correction<T> kernel(w_supp);
+  size_t nthreads = gconf.Nthreads();
+  EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
   T wmin;
   vector<T> wval;
   double dw;
@@ -1204,7 +1204,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
 
 
 template<typename T> size_t getIdxSize(const Baselines<T> &baselines,
-  const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax)
+  const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax, size_t nthreads)
   {
   size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels();
   if (chbegin<0) chbegin=0;
@@ -1318,9 +1318,8 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
   const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
   size_t nthreads, const mav<complex<T>,2> &dirty, size_t verbosity)
   {
-  set_nthreads(nthreads);
   Baselines<T> baselines(uvw, freq);
-  GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y);
+  GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
   auto idx = getWgtIndices(baselines, gconf, wgt);
   auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
   auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
@@ -1332,9 +1331,8 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
   const const_mav<T,2> &wgt, T pixsize_x, T pixsize_y, double epsilon,
   size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
   {
-  set_nthreads(nthreads);
   Baselines<T> baselines(uvw, freq);
-  GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y);
+  GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
   auto idx = getWgtIndices(baselines, gconf, wgt);
   auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
   for (size_t i=0; i<ms.shape(0); ++i)
@@ -1349,8 +1347,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
 // public names
 using detail::mav;
 using detail::const_mav;
-using detail::set_nthreads;
-using detail::get_nthreads;
 using detail::myassert;
 using detail::Baselines;
 using detail::GridderConfig;
diff --git a/nifty_gridder.cc b/nifty_gridder.cc
index 24bfc9439f36467eae177aa431a9c3f5cd20cbd3..238800a9225d9114ecf7db796c8c554eb754d374 100644
--- a/nifty_gridder.cc
+++ b/nifty_gridder.cc
@@ -33,29 +33,6 @@ namespace {
 
 auto None = py::none();
 
-//
-// basic utilities
-//
-
-//
-// Start of real gridder functionality
-//
-constexpr auto set_nthreads_DS = R"""(
-Specifies the number of threads to be used by the module
-
-Parameters
-==========
-nthreads: int
-    the number of threads to be used. Must be >=1.
-)""";
-constexpr auto get_nthreads_DS = R"""(
-Returns the number of threads used by the module
-
-Returns
-=======
-int : the number of threads used by the module
-)""";
-
 template<typename T>
   using pyarr = py::array_t<T, 0>;
 
@@ -210,7 +187,7 @@ template<typename T> class PyBaselines: public Baselines<T>
       }
 
     template<typename T2> pyarr<T2> ms2vis(const pyarr<T2> &ms_,
-      const pyarr<uint32_t> &idx_) const
+      const pyarr<uint32_t> &idx_, size_t nthreads) const
       {
       auto idx=make_const_mav<1>(idx_);
       size_t nvis = size_t(idx.shape(0));
@@ -219,7 +196,7 @@ template<typename T> class PyBaselines: public Baselines<T>
       auto vis = make_mav<1>(res);
       {
       py::gil_scoped_release release;
-      Baselines<T>::ms2vis(ms, idx, vis);
+      Baselines<T>::ms2vis(ms, idx, vis, nthreads);
       }
       return res;
       }
@@ -242,7 +219,7 @@ template<typename T> class PyBaselines: public Baselines<T>
         the measurement set's visibility data (0 where not covered by idx)
     )""";
     template<typename T2> pyarr<T2> vis2ms(const pyarr<T2> &vis_,
-      const pyarr<uint32_t> &idx_, py::object &ms_in) const
+      const pyarr<uint32_t> &idx_, py::object &ms_in, size_t nthreads) const
       {
       auto vis=make_const_mav<1>(vis_);
       auto idx=make_const_mav<1>(idx_);
@@ -250,7 +227,7 @@ template<typename T> class PyBaselines: public Baselines<T>
       auto ms = make_mav<2>(res);
       {
       py::gil_scoped_release release;
-      Baselines<T>::vis2ms(vis, idx, ms);
+      Baselines<T>::vis2ms(vis, idx, ms, nthreads);
       }
       return res;
       }
@@ -318,8 +295,8 @@ template<typename T> class PyGridderConfig: public GridderConfig<T>
   public:
     using GridderConfig<T>::GridderConfig;
     PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
-      double pixsize_x, double pixsize_y)
-      : GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y) {}
+      double pixsize_x, double pixsize_y, size_t nthreads)
+      : GridderConfig<T>(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads) {}
     using GridderConfig<T>::Nxdirty;
     using GridderConfig<T>::Nydirty;
     using GridderConfig<T>::Epsilon;
@@ -616,13 +593,13 @@ np.array((nvis,), dtype=np.uint32)
 )""";
 template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baselines,
   const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin,
-  int chend, T wmin, T wmax)
+  int chend, T wmin, T wmax, size_t nthreads)
   {
   size_t nidx;
   auto flags = make_const_mav<2>(flags_);
   {
   py::gil_scoped_release release;
-  nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax);
+  nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax, nthreads);
   }
   auto res = makeArray<uint32_t>({nidx});
   auto res2 = make_mav<1>(res);
@@ -709,13 +686,13 @@ PYBIND11_MODULE(nifty_gridder, m)
     .def ("Nrows",&PyBaselines<double>::Nrows)
     .def ("Nchannels",&PyBaselines<double>::Nchannels)
     .def ("ms2vis",&PyBaselines<double>::ms2vis<complex<double>>,
-      PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a)
+      PyBaselines<double>::ms2vis_DS, "ms"_a, "idx"_a, "nthreads"_a=1)
     .def ("effectiveuvw",&PyBaselines<double>::effectiveuvw, "idx"_a)
     .def ("vis2ms",&PyBaselines<double>::vis2ms<complex<double>>,
-      PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None);
+      PyBaselines<double>::vis2ms_DS, "vis"_a, "idx"_a, "ms_in"_a=None, "nthreads"_a=1);
   py::class_<PyGridderConfig<double>> (m, "GridderConfig", GridderConfig_DS)
-    .def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a,
-      "nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a)
+    .def(py::init<size_t, size_t, double, double, double, size_t>(),"nxdirty"_a,
+      "nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1)
     .def("Nxdirty", &PyGridderConfig<double>::Nxdirty)
     .def("Nydirty", &PyGridderConfig<double>::Nydirty)
     .def("Epsilon", &PyGridderConfig<double>::Epsilon)
@@ -737,21 +714,21 @@ PYBIND11_MODULE(nifty_gridder, m)
         [](const PyGridderConfig<double> & gc) {
           // Encode object state in tuple
           return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(),
-                                gc.Pixsize_x(), gc.Pixsize_y());
+                                gc.Pixsize_x(), gc.Pixsize_y(), gc.Nthreads());
         },
         // __setstate__
         [](py::tuple t) {
-          myassert(t.size()==5,"Invalid state");
+          myassert(t.size()==5 || t.size()==6,"Invalid state"); // 5 entries: state written before nthreads was stored
 
           // Reconstruct from tuple
           return PyGridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(),
                                        t[2].cast<double>(), t[3].cast<double>(),
-                                       t[4].cast<double>());
+                                       t[4].cast<double>(), (t.size()==6) ? t[5].cast<size_t>() : size_t(1)); // 5-entry state: use default nthreads=1
 
         }));
   m.def("getIndices", PygetIndices<double>, getIndices_DS, "baselines"_a,
     "gconf"_a, "flags"_a, "chbegin"_a=-1, "chend"_a=-1,
-    "wmin"_a=-1e30, "wmax"_a=1e30);
+    "wmin"_a=-1e30, "wmax"_a=1e30, "nthreads"_a=1);
   m.def("vis2grid_c",&Pyvis2grid_c<double>, vis2grid_c_DS, "baselines"_a,
     "gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None, "wgt"_a=None);
   m.def("ms2grid_c",&Pyms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a,
@@ -764,8 +741,6 @@ PYBIND11_MODULE(nifty_gridder, m)
     "grid"_a, "wgt"_a=None);
   m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a,
     "idx"_a, "du"_a, "dv"_a, "wgt"_a=None);
-  m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
-  m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
   m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a,
     "idx"_a, "vis"_a);
   m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,