diff --git a/nifty_gridder.cc b/nifty_gridder.cc index 24053aa2a08912611a8ba84dba9f912bbbe18a81..62c9a3f9e22321146d5fc33d57c1dcba5898943a 100644 --- a/nifty_gridder.cc +++ b/nifty_gridder.cc @@ -691,7 +691,8 @@ template<typename T> pyarr_c<T> vis2grid(const Baselines<T> &baselines, template<typename T> pyarr_c<complex<T>> ms2grid_c( const Baselines<T> &baselines, const GridderConfig<T> &gconf, - const pyarr_c<uint32_t> &idx_, const pyarr_c<complex<T>> &ms_) + const pyarr_c<uint32_t> &idx_, const pyarr_c<complex<T>> &ms_, + py::object user_grid) { auto nrows = baselines.Nrows(); auto nchan = baselines.Nchannels(); @@ -701,8 +702,11 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c( auto ms = ms_.data(); auto idx = idx_.data(); + bool have_user_grid = !user_grid.is(py::none()); + size_t nu=gconf.Nu(), nv=gconf.Nv(); - auto res = makeArray<complex<T>>({nu, nv}); + auto res = have_user_grid ? user_grid.cast<pyarr_c<complex<T>>>() + : makeArray<complex<T>>({nu, nv}); auto grid = res.mutable_data(); { py::gil_scoped_release release; @@ -740,8 +744,8 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c( template<typename T> pyarr_c<T> ms2grid(const Baselines<T> &baselines, const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_, - const pyarr_c<complex<T>> &ms_) - { return complex2hartley(ms2grid_c(baselines, gconf, idx_, ms_)); } + const pyarr_c<complex<T>> &ms_, py::object user_grid) + { return complex2hartley(ms2grid_c(baselines, gconf, idx_, ms_, user_grid)); } template<typename T> pyarr_c<complex<T>> grid2vis_c(const Baselines<T> &baselines, const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_, @@ -1126,7 +1130,8 @@ PYBIND11_MODULE(nifty_gridder, m) "wmin"_a=-1e30, "wmax"_a=1e30); m.def("vis2grid",&vis2grid<double>, vis2grid_DS, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a, "user_grid"_a); - m.def("ms2grid",&ms2grid<double>, "baselines"_a, "gconf"_a, "idx"_a, "ms"_a); + m.def("ms2grid",&ms2grid<double>, "baselines"_a, "gconf"_a, "idx"_a, + "ms"_a, "user_grid"_a); m.def("grid2vis",&grid2vis<double>, grid2vis_DS, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a); m.def("grid2ms",&grid2ms<double>, "baselines"_a, "gconf"_a, "idx"_a, @@ -1134,7 +1139,7 @@ PYBIND11_MODULE(nifty_gridder, m) m.def("vis2grid_c",&vis2grid_c<double>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a, "user_grid"_a); m.def("ms2grid_c",&ms2grid_c<double>, "baselines"_a, "gconf"_a, "idx"_a, - "ms"_a); + "ms"_a, "user_grid"_a); m.def("grid2vis_c",&grid2vis_c<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a); m.def("grid2ms_c",&grid2ms_c<double>, "baselines"_a, "gconf"_a, "idx"_a,