diff --git a/nifty_gridder.cc b/nifty_gridder.cc index 02abf6700d91a21fa792c8f46145a8e9e74016f0..dc2d4ed832c27349c587c124a60badcef8982608 100644 --- a/nifty_gridder.cc +++ b/nifty_gridder.cc @@ -852,6 +852,62 @@ template<typename T> pyarr_c<complex<T>> grid2ms(const Baselines<T> &baselines, const pyarr_c<T> &grid_, py::object &ms_in) { return grid2ms_c(baselines, gconf, idx_, hartley2complex(grid_), ms_in); } +template<typename T> pyarr_c<complex<T>> apply_holo( + const Baselines<T> &baselines, const GridderConfig<T> &gconf, + const pyarr_c<uint32_t> &idx_, const pyarr_c<complex<T>> &grid_) + { + size_t nu=gconf.Nu(), nv=gconf.Nv(); + checkArray(idx_, "idx", {0}); + auto grid = grid_.data(); + checkArray(grid_, "grid", {nu, nv}); + size_t nvis = size_t(idx_.shape(0)); + auto idx = idx_.data(); + + auto res = makeArray<complex<T>>({nu, nv}); + auto ogrid = res.mutable_data(); + { + py::gil_scoped_release release; + T beta = gconf.Beta(); + size_t w = gconf.W(); + + // Loop over sampling points +#pragma omp parallel +{ + Helper<T> hlp(gconf, grid, ogrid); + T emb = exp(-2*beta); + int jump = hlp.lineJump(); + const T * RESTRICT ku = hlp.kernel.data(); + const T * RESTRICT kv = hlp.kernel.data()+w; + +#pragma omp for schedule(guided,100) + for (size_t ipart=0; ipart<nvis; ++ipart) + { + UVW<T> coord = baselines.effectiveCoord(idx[ipart]); + hlp.prep(coord.u, coord.v); + complex<T> r = 0; + const auto * RESTRICT ptr = hlp.p0r; + for (size_t cu=0; cu<w; ++cu) + { + complex<T> tmp(0); + for (size_t cv=0; cv<w; ++cv) + tmp += ptr[cv] * kv[cv]; + r += tmp*ku[cu]; + ptr += jump; + } + r*=emb*emb; + auto * RESTRICT wptr = hlp.p0w; + for (size_t cu=0; cu<w; ++cu) + { + complex<T> tmp(r*ku[cu]); + for (size_t cv=0; cv<w; ++cv) + wptr[cv] += tmp*kv[cv]; + wptr+=jump; + } + } +} + } + return res; + } template<typename T> pyarr_c<uint32_t> getIndices(const Baselines<T> &baselines, const GridderConfig<T> &gconf, const pyarr_c<bool> &flags_, int chbegin, int chend, T wmin, T wmax) @@ -1138,4 +1194,6 @@ PYBIND11_MODULE(nifty_gridder, m) "grid"_a); m.def("grid2ms_c",&grid2ms_c<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a, "ms_in"_a=py::none()); + m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a, + "grid"_a); }