Commit 855637b6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

move more code to gridder_cxx.h

parent 00fb5b31
...@@ -34,6 +34,8 @@ ...@@ -34,6 +34,8 @@
namespace gridder { namespace gridder {
namespace detail {
using namespace std; using namespace std;
template<typename T, size_t ndim> class mav template<typename T, size_t ndim> class mav
...@@ -667,6 +669,7 @@ template<class T, class T2, class T3> class VisServ ...@@ -667,6 +669,7 @@ template<class T, class T2, class T3> class VisServ
if (have_wgt) checkShape(wgt.shape(), {nvis}); if (have_wgt) checkShape(wgt.shape(), {nvis});
} }
size_t Nvis() const { return nvis; } size_t Nvis() const { return nvis; }
const Baselines<T> &getBaselines() const { return baselines; }
UVW<T> getCoord(size_t i) const UVW<T> getCoord(size_t i) const
{ return baselines.effectiveCoord(idx[i]); } { return baselines.effectiveCoord(idx[i]); }
complex<T> getVis(size_t i) const complex<T> getVis(size_t i) const
...@@ -700,6 +703,7 @@ template<class T, class T2, class T3> class MsServ ...@@ -700,6 +703,7 @@ template<class T, class T2, class T3> class MsServ
if (have_wgt) checkShape(wgt.shape(), {nrows, nchan}); if (have_wgt) checkShape(wgt.shape(), {nrows, nchan});
} }
size_t Nvis() const { return nvis; } size_t Nvis() const { return nvis; }
const Baselines<T> &getBaselines() const { return baselines; }
UVW<T> getCoord(size_t i) const UVW<T> getCoord(size_t i) const
{ return baselines.effectiveCoord(idx(i)); } { return baselines.effectiveCoord(idx(i)); }
complex<T> getVis(size_t i) const complex<T> getVis(size_t i) const
...@@ -1045,7 +1049,7 @@ template<typename T, typename Serv> void wstack_common( ...@@ -1045,7 +1049,7 @@ template<typename T, typename Serv> void wstack_common(
} }
} }
template<typename T, typename Serv> void x2dirty_wstack(const Baselines<T> &baselines, template<typename T, typename Serv> void x2dirty_wstack(
const GridderConfig<T> &gconf, const Serv &srv, const mav<complex<T>,2> &dirty, size_t verbosity=0) const GridderConfig<T> &gconf, const Serv &srv, const mav<complex<T>,2> &dirty, size_t verbosity=0)
{ {
auto nx_dirty=gconf.Nxdirty(); auto nx_dirty=gconf.Nxdirty();
...@@ -1101,7 +1105,7 @@ template<typename T, typename Serv> void x2dirty_wstack(const Baselines<T> &base ...@@ -1101,7 +1105,7 @@ template<typename T, typename Serv> void x2dirty_wstack(const Baselines<T> &base
idx_loc[i] = srv.getIdx(ipart); idx_loc[i] = srv.getIdx(ipart);
} }
grid_.fill(0.); grid_.fill(0.);
vis2grid_c(baselines, gconf, const_mav<uint32_t,1>(idx_loc), const_mav<complex<T>,1>(vis_loc), grid, const_mav<T,1>(nullptr,{0})); vis2grid_c(srv.getBaselines(), gconf, const_mav<uint32_t,1>(idx_loc), const_mav<complex<T>,1>(vis_loc), grid, const_mav<T,1>(nullptr,{0}));
gconf.grid2dirty_c(grid, tdirty); gconf.grid2dirty_c(grid, tdirty);
gconf.apply_wscreen(tdirty, tdirty, wcur, false); gconf.apply_wscreen(tdirty, tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
...@@ -1117,10 +1121,10 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines, ...@@ -1117,10 +1121,10 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx, const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<complex<T>,2> &dirty) const const_mav<complex<T>,1> &vis, mav<complex<T>,2> &dirty)
{ {
x2dirty_wstack(baselines, gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty); x2dirty_wstack(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty);
} }
template<typename T, typename Serv> void dirty2x_wstack(const Baselines<T> &baselines, template<typename T, typename Serv> void dirty2x_wstack(
const GridderConfig<T> &gconf, const const_mav<complex<T>,2> &dirty, const GridderConfig<T> &gconf, const const_mav<complex<T>,2> &dirty,
const Serv &srv, size_t verbosity=0) const Serv &srv, size_t verbosity=0)
{ {
...@@ -1179,7 +1183,7 @@ template<typename T, typename Serv> void dirty2x_wstack(const Baselines<T> &base ...@@ -1179,7 +1183,7 @@ template<typename T, typename Serv> void dirty2x_wstack(const Baselines<T> &base
gconf.apply_wscreen(tdirty, tdirty2, wcur, true); gconf.apply_wscreen(tdirty, tdirty2, wcur, true);
gconf.dirty2grid_c(tdirty2, grid); gconf.dirty2grid_c(tdirty2, grid);
grid2vis_c(baselines, gconf, const_mav<uint32_t,1>(idx_loc), const_mav<complex<T>,2>(grid), vis_loc, const_mav<T,1>(nullptr,{0})); grid2vis_c(srv.getBaselines(), gconf, const_mav<uint32_t,1>(idx_loc), const_mav<complex<T>,2>(grid), vis_loc, const_mav<T,1>(nullptr,{0}));
T tmpval=T(2)/(w_supp*dw); T tmpval=T(2)/(w_supp*dw);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
...@@ -1195,7 +1199,79 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines, ...@@ -1195,7 +1199,79 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx, const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,2> &dirty, mav<complex<T>,1> &vis) const const_mav<complex<T>,2> &dirty, mav<complex<T>,1> &vis)
{ {
dirty2x_wstack(baselines, gconf, dirty, VisServ<T,mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0}))); dirty2x_wstack(gconf, dirty, VisServ<T,mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})));
}
template<typename T> size_t getIdxSize(const Baselines<T> &baselines,
const const_mav<bool,2> &flags, int chbegin, int chend, T wmin, T wmax)
{
size_t nrow=baselines.Nrows(), nchan=baselines.Nchannels();
if (chbegin<0) chbegin=0;
if (chend<0) chend=nchan;
myassert(chend>chbegin, "empty channel range selected");
myassert(chend<=int(nchan), "chend too large");
myassert(wmax>wmin, "empty w range selected");
checkShape(flags.shape(), {nrow, nchan});
size_t res=0;
#pragma omp parallel for num_threads(nthreads) reduction(+:res)
for (size_t irow=0; irow<nrow; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags(irow,ichan))
{
auto uvw = baselines.effectiveCoord(RowChan{irow,size_t(ichan)});
if ((uvw.w>=wmin) && (uvw.w<wmax))
++res;
}
return res;
}
template<typename T> void fillIdx(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<bool,2> &flags, int chbegin,
int chend, T wmin, T wmax, mav<uint32_t,1> &res)
{
size_t nrow=baselines.Nrows(),
nchan=baselines.Nchannels(),
nsafe=gconf.Nsafe();
if (chbegin<0) chbegin=0;
if (chend<0) chend=nchan;
myassert(chend>chbegin, "empty channel range selected");
myassert(chend<=int(nchan), "chend too large");
myassert(wmax>wmin, "empty w range selected");
checkShape(flags.shape(), {nrow, nchan});
constexpr int side=1<<logsquare;
size_t nbu = (gconf.Nu()+1+side-1) >> logsquare,
nbv = (gconf.Nv()+1+side-1) >> logsquare;
vector<uint32_t> acc(nbu*nbv+1, 0);
vector<uint32_t> tmp(nrow*(chend-chbegin));
for (size_t irow=0, idx=0; irow<nrow; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags(irow, ichan))
{
auto uvw = baselines.effectiveCoord(RowChan{irow,size_t(ichan)});
if ((uvw.w>=wmin) && (uvw.w<wmax))
{
T u, v;
int iu0, iv0;
gconf.getpix(uvw.u, uvw.v, u, v, iu0, iv0);
iu0 = (iu0+nsafe)>>logsquare;
iv0 = (iv0+nsafe)>>logsquare;
++acc[nbv*iu0 + iv0 + 1];
tmp[idx++] = nbv*iu0 + iv0;
}
}
for (size_t i=1; i<acc.size(); ++i)
acc[i] += acc[i-1];
myassert(res.shape(0)==acc.back(), "array size mismatch");
for (size_t irow=0, idx=0; irow<nrow; ++irow)
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags(irow, ichan))
{
auto uvw = baselines.effectiveCoord(RowChan{irow,size_t(ichan)});
if ((uvw.w>=wmin) && (uvw.w<wmax))
res[acc[tmp[idx++]]++] = irow*nchan+ichan;
}
} }
template<typename T> vector<uint32_t> getWgtIndices(const Baselines<T> &baselines, template<typename T> vector<uint32_t> getWgtIndices(const Baselines<T> &baselines,
...@@ -1248,7 +1324,7 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw, ...@@ -1248,7 +1324,7 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
auto idx = getWgtIndices(baselines, gconf, wgt); auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt); auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
x2dirty_wstack(baselines, gconf, srv, dirty, verbosity); x2dirty_wstack(gconf, srv, dirty, verbosity);
} }
template<typename T> void full_degridding(const const_mav<T,2> &uvw, template<typename T> void full_degridding(const const_mav<T,2> &uvw,
...@@ -1265,9 +1341,27 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, ...@@ -1265,9 +1341,27 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
for (size_t j=0; j<ms.shape(1); ++j) for (size_t j=0; j<ms.shape(1); ++j)
ms(i,j) = 0; ms(i,j) = 0;
auto srv = MsServ<T,mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt); auto srv = MsServ<T,mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
dirty2x_wstack(baselines, gconf, dirty, srv, verbosity); dirty2x_wstack(gconf, dirty, srv, verbosity);
} }
} // namespace detail
// public names
using detail::mav;
using detail::const_mav;
using detail::set_nthreads;
using detail::get_nthreads;
using detail::myassert;
using detail::Baselines;
using detail::GridderConfig;
using detail::ms2grid_c;
using detail::grid2ms_c;
using detail::vis2grid_c;
using detail::grid2vis_c;
using detail::vis2dirty_wstack;
using detail::dirty2vis_wstack;
using detail::full_gridding;
using detail::full_degridding;
} // namespace gridder } // namespace gridder
#endif #endif
...@@ -618,55 +618,17 @@ template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baseline ...@@ -618,55 +618,17 @@ template<typename T> pyarr<uint32_t> PygetIndices(const PyBaselines<T> &baseline
const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin, const PyGridderConfig<T> &gconf, const pyarr<bool> &flags_, int chbegin,
int chend, T wmin, T wmax) int chend, T wmin, T wmax)
{ {
size_t nrow=baselines.Nrows(), size_t nidx;
nchan=baselines.Nchannels(), auto flags = make_const_mav<2>(flags_);
nsafe=gconf.Nsafe();
if (chbegin<0) chbegin=0;
if (chend<0) chend=nchan;
myassert(chend>chbegin, "empty channel range selected");
myassert(chend<=int(nchan), "chend too large");
myassert(wmax>wmin, "empty w range selected");
checkArray(flags_, "flags", {nrow, nchan});
auto flags = flags_.data();
constexpr int side=1<<logsquare;
size_t nbu = (gconf.Nu()+1+side-1) >> logsquare,
nbv = (gconf.Nv()+1+side-1) >> logsquare;
vector<uint32_t> acc(nbu*nbv+1, 0);
vector<uint32_t> tmp(nrow*(chend-chbegin));
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
for (size_t irow=0, idx=0; irow<nrow; ++irow) nidx = getIdxSize(baselines, flags, chbegin, chend, wmin, wmax);
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags[irow*nchan+ichan])
{
auto uvw = baselines.effectiveCoord(RowChan{irow,size_t(ichan)});
if ((uvw.w>=wmin) && (uvw.w<wmax))
{
T u, v;
int iu0, iv0;
gconf.getpix(uvw.u, uvw.v, u, v, iu0, iv0);
iu0 = (iu0+nsafe)>>logsquare;
iv0 = (iv0+nsafe)>>logsquare;
++acc[nbv*iu0 + iv0 + 1];
tmp[idx++] = nbv*iu0 + iv0;
}
}
for (size_t i=1; i<acc.size(); ++i)
acc[i] += acc[i-1];
} }
auto res = makeArray<uint32_t>({acc.back()}); auto res = makeArray<uint32_t>({nidx});
auto iout = res.mutable_data(); auto res2 = make_mav<1>(res);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
for (size_t irow=0, idx=0; irow<nrow; ++irow) fillIdx(baselines, gconf, flags, chbegin, chend, wmin, wmax, res2);
for (int ichan=chbegin; ichan<chend; ++ichan)
if (!flags[irow*nchan+ichan])
{
auto uvw = baselines.effectiveCoord(RowChan{irow,size_t(ichan)});
if ((uvw.w>=wmin) && (uvw.w<wmax))
iout[acc[tmp[idx++]]++] = irow*nchan+ichan;
}
} }
return res; return res;
} }
...@@ -809,10 +771,10 @@ PYBIND11_MODULE(nifty_gridder, m) ...@@ -809,10 +771,10 @@ PYBIND11_MODULE(nifty_gridder, m)
m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a, m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,
"idx"_a, "dirty"_a); "idx"_a, "dirty"_a);
m.def("full_gridding",&Pyfull_gridding<double>,"uvw"_a,"freq"_a,"vis"_a, m.def("full_gridding",&Pyfull_gridding<double>,"uvw"_a,"freq"_a,"ms"_a,
"wgt"_a,"npix_x"_a,"npix_y"_a,"pixsize_x"_a,"pixsize_y"_a,"epsilon"_a, "wgt"_a,"npix_x"_a,"npix_y"_a,"pixsize_x"_a,"pixsize_y"_a,"epsilon"_a,
"nthreads"_a=1, "verbosity"_a=0); "nthreads"_a=1, "verbosity"_a=0);
m.def("full_gridding_f",&Pyfull_gridding<float>,"uvw"_a,"freq"_a,"vis"_a, m.def("full_gridding_f",&Pyfull_gridding<float>,"uvw"_a,"freq"_a,"ms"_a,
"wgt"_a,"npix_x"_a,"npix_y"_a,"pixsize_x"_a,"pixsize_y"_a,"epsilon"_a, "wgt"_a,"npix_x"_a,"npix_y"_a,"pixsize_x"_a,"pixsize_y"_a,"epsilon"_a,
"nthreads"_a=1, "verbosity"_a=0); "nthreads"_a=1, "verbosity"_a=0);
m.def("full_degridding",&Pyfull_degridding<double>,"uvw"_a,"freq"_a,"dirty"_a, m.def("full_degridding",&Pyfull_degridding<double>,"uvw"_a,"freq"_a,"dirty"_a,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment