Commit ba67cb5d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent de316727
......@@ -852,6 +852,9 @@ template<class T, class T2> class VisServ
checkShape(idx.shape(), {nvis});
if (have_wgt) checkShape(wgt.shape(), {nvis});
}
VisServ(const VisServ &orig, const const_mav<uint32_t,1> &/*newidx*/)
: baselines(orig.baselines), idx(orig.idx), vis(orig.vis), wgt(orig.wgt)
{ myfail("must not happen"); }
size_t Nvis() const { return nvis; }
const Baselines &getBaselines() const { return baselines; }
UVW getCoord(size_t i) const
......@@ -863,6 +866,7 @@ template<class T, class T2> class VisServ
{ vis(i) = have_wgt ? v*wgt(i) : v; }
void addVis (size_t i, const complex<T> &v) const
{ vis(i) += have_wgt ? v*wgt(i) : v; }
bool isMsServ() const { return false; }
};
template<class T, class T2> VisServ<T, T2> makeVisServ
(const Baselines &baselines,
......@@ -911,6 +915,7 @@ template<class T, class T2> class MsServ
auto rc = baselines.getRowChan(idx(i));
ms(rc.row, rc.chan) += have_wgt ? v*wgt(rc.row, rc.chan) : v;
}
bool isMsServ() const { return true; }
};
template<class T, class T2> MsServ<T, T2> makeMsServ
(const Baselines &baselines,
......@@ -1203,43 +1208,6 @@ template<typename T> void apply_wcorr(const GridderConfig &gconf,
}
}
template<typename T, typename Serv> void x2dirty(
const GridderConfig &gconf, const Serv &srv, const mav<T,2> &dirty,
size_t verbosity=0)
{
if (verbosity>0)
cout << "Gridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
tmpStorage<complex<T>,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
grid_.fill(0.);
x2grid_c(gconf, srv, grid);
tmpStorage<T,2> rgrid_(grid.shape());
auto rgrid=rgrid_.getMav();
complex2hartley(cmav(grid), rgrid, gconf.Nthreads());
gconf.grid2dirty(cmav(rgrid), dirty);
}
template<typename T, typename Serv> void dirty2x(
const GridderConfig &gconf, const const_mav<T,2> &dirty,
const Serv &srv, size_t verbosity=0)
{
if (verbosity>0)
cout << "Degridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
tmpStorage<T,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
gconf.dirty2grid(dirty, grid);
tmpStorage<complex<T>,2> grid2_(grid.shape());
auto grid2=grid2_.getMav();
hartley2complex(cmav(grid), grid2, gconf.Nthreads());
grid2x_c(gconf, cmav(grid2), srv);
}
template<typename Serv> void wstack_common(
const GridderConfig &gconf, const Serv &srv, double &wmin,
double &dw, size_t &nplanes, vector<size_t> &nvis_plane,
......@@ -1294,199 +1262,193 @@ template<typename Serv> void wstack_common(
}
}
template<typename T, typename Serv> void x2dirty_wstack(
template<typename T, typename Serv> void x2dirty(
const GridderConfig &gconf, const Serv &srv, const mav<T,2> &dirty,
size_t verbosity=0)
bool do_wstacking, size_t verbosity=0)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nvis = srv.Nvis();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Gridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
dirty.fill(0);
vector<complex<T>> vis_loc_;
vector<uint32_t> idx_loc_, mapper;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty_(dirty.shape());
auto tdirty=tdirty_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
if (do_wstacking)
{
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
size_t cnt=0;
vis_loc_.resize(nvis_plane[iw]);
auto vis_loc = mav<complex<T>,1>(vis_loc_.data(), {nvis_plane[iw]});
idx_loc_.resize(nvis_plane[iw]);
auto idx_loc = mav<uint32_t,1>(idx_loc_.data(), {nvis_plane[iw]});
mapper.resize(nvis_plane[iw]);
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
mapper[cnt++] = ipart;
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis_plane[iw]; ++i)
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nvis = srv.Nvis();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Gridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
dirty.fill(0);
vector<complex<T>> vis_loc_;
vector<uint32_t> idx_loc_, newidx;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty_(dirty.shape());
auto tdirty=tdirty_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
{
auto ipart = mapper[i];
vis_loc[i] = srv.getVis(ipart);
idx_loc[i] = srv.getIdx(ipart);
}
grid.fill(0);
vis2grid_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(vis_loc), grid,
nullmav<T,1>(), wcur, dw);
gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
size_t cnt=0;
newidx.resize(nvis_plane[iw]);
if (srv.isMsServ())
{
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
newidx[cnt++] = srv.getIdx(ipart);
grid.fill(0);
Serv newsrv(srv, const_mav<uint32_t, 1>(newidx.data(), {newidx.size()}));
x2grid_c(gconf, newsrv, grid, wcur, dw);
}
else
{
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
newidx[cnt++] = ipart;
vis_loc_.resize(nvis_plane[iw]);
auto vis_loc = mav<complex<T>,1>(vis_loc_.data(), {nvis_plane[iw]});
idx_loc_.resize(nvis_plane[iw]);
auto idx_loc = mav<uint32_t,1>(idx_loc_.data(), {nvis_plane[iw]});
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
dirty(i,j) += tdirty(i,j).real();
for (size_t i=0; i<nvis_plane[iw]; ++i)
{
auto ipart = newidx[i];
vis_loc[i] = srv.getVis(ipart);
idx_loc[i] = srv.getIdx(ipart);
}
grid.fill(0);
vis2grid_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(vis_loc), grid,
nullmav<T,1>(), wcur, dw);
}
gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
dirty(i,j) += tdirty(i,j).real();
}
// correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel_with_correction(supp, nthreads), dw);
}
// correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel_with_correction(supp, nthreads), dw);
}
template<typename T> void vis2dirty_wstack(const Baselines &baselines,
const GridderConfig &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<T,2> &dirty)
{
x2dirty_wstack(gconf, makeVisServ(baselines, idx, vis, nullmav<T,1>()), dirty);
}
template<typename T, typename Serv> void x2dirty_wstack2(
const GridderConfig &gconf, const Serv &srv, const mav<T,2> &dirty,
size_t verbosity=0)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nvis = srv.Nvis();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Gridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
dirty.fill(0);
vector<uint32_t> newidx;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty_(dirty.shape());
auto tdirty=tdirty_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
else
{
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
if (verbosity>0)
cout << "Gridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
size_t cnt=0;
newidx.resize(nvis_plane[iw]);
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
newidx[cnt++] = srv.getIdx(ipart);
grid.fill(0);
Serv newsrv(srv, const_mav<uint32_t, 1>(newidx.data(), {newidx.size()}));
x2grid_c(gconf, newsrv, grid, wcur, dw);
gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
dirty(i,j) += tdirty(i,j).real();
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
tmpStorage<complex<T>,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
grid_.fill(0.);
x2grid_c(gconf, srv, grid);
tmpStorage<T,2> rgrid_(grid.shape());
auto rgrid=rgrid_.getMav();
complex2hartley(cmav(grid), rgrid, gconf.Nthreads());
gconf.grid2dirty(cmav(rgrid), dirty);
}
// correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel_with_correction(supp, nthreads), dw);
}
template<typename T> void vis2dirty_wstack2(const Baselines &baselines,
template<typename T> void vis2dirty_wstack(const Baselines &baselines,
const GridderConfig &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<T,2> &dirty)
{
x2dirty_wstack2(gconf, makeVisServ(baselines, idx, vis, nullmav<T,1>()), dirty);
x2dirty(gconf, makeVisServ(baselines, idx, vis, nullmav<T,1>()), dirty, true);
}
template<typename T, typename Serv> void dirty2x_wstack(
template<typename T, typename Serv> void dirty2x(
const GridderConfig &gconf, const const_mav<T,2> &dirty,
const Serv &srv, size_t verbosity=0)
const Serv &srv, bool do_wstacking, size_t verbosity=0)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nvis = srv.Nvis();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Degridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
tmpStorage<T,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty(i,j) = dirty(i,j);
// correct for w gridding
apply_wcorr(gconf, tdirty, ES_Kernel_with_correction(supp, nthreads), dw);
vector<complex<T>> vis_loc_;
vector<uint32_t> idx_loc_, mapper;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty2_({nx_dirty,ny_dirty});
auto tdirty2=tdirty2_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
if (do_wstacking)
{
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
size_t cnt=0;
vis_loc_.resize(nvis_plane[iw]);
auto vis_loc = mav<complex<T>,1>(vis_loc_.data(), {nvis_plane[iw]});
idx_loc_.resize(nvis_plane[iw]);
auto idx_loc = mav<uint32_t,1>(idx_loc_.data(), {nvis_plane[iw]});
mapper.resize(nvis_plane[iw]);
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
mapper[cnt++] = ipart;
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis_plane[iw]; ++i)
idx_loc[i] = srv.getIdx(mapper[i]);
#pragma omp parallel for num_threads(nthreads)
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nvis = srv.Nvis();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Degridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
tmpStorage<T,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty2(i,j) = tdirty(i,j);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, true);
gconf.dirty2grid_c(cmav(tdirty2), grid);
grid2vis_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(grid), vis_loc,
nullmav<T,1>(), wcur, dw);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis_plane[iw]; ++i)
srv.addVis(mapper[i], vis_loc[i]);
tdirty(i,j) = dirty(i,j);
// correct for w gridding
apply_wcorr(gconf, tdirty, ES_Kernel_with_correction(supp, nthreads), dw);
vector<complex<T>> vis_loc_;
vector<uint32_t> idx_loc_, mapper;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty2_({nx_dirty,ny_dirty});
auto tdirty2=tdirty2_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
{
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
size_t cnt=0;
vis_loc_.resize(nvis_plane[iw]);
auto vis_loc = mav<complex<T>,1>(vis_loc_.data(), {nvis_plane[iw]});
idx_loc_.resize(nvis_plane[iw]);
auto idx_loc = mav<uint32_t,1>(idx_loc_.data(), {nvis_plane[iw]});
mapper.resize(nvis_plane[iw]);
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+supp))
mapper[cnt++] = ipart;
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis_plane[iw]; ++i)
idx_loc[i] = srv.getIdx(mapper[i]);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty2(i,j) = tdirty(i,j);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, true);
gconf.dirty2grid_c(cmav(tdirty2), grid);
grid2vis_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(grid), vis_loc,
nullmav<T,1>(), wcur, dw);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis_plane[iw]; ++i)
srv.addVis(mapper[i], vis_loc[i]);
}
}
else
{
if (verbosity>0)
cout << "Degridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
tmpStorage<T,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
gconf.dirty2grid(dirty, grid);
tmpStorage<complex<T>,2> grid2_(grid.shape());
auto grid2=grid2_.getMav();
hartley2complex(cmav(grid), grid2, gconf.Nthreads());
grid2x_c(gconf, cmav(grid2), srv);
}
}
template<typename T> void dirty2vis_wstack(const Baselines &baselines,
const GridderConfig &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<T,2> &dirty, mav<complex<T>,1> &vis)
{
dirty2x_wstack(gconf, dirty, makeVisServ(baselines, idx, vis, nullmav<T,1>()));
dirty2x(gconf, dirty, makeVisServ(baselines, idx, vis, nullmav<T,1>()), true);
}
......@@ -1606,51 +1568,26 @@ template<typename T> vector<uint32_t> getWgtIndices(const Baselines &baselines,
template<typename T> void gridding(const const_mav<double,2> &uvw,
const const_mav<double,1> &freq, const const_mav<complex<T>,2> &ms,
const const_mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
size_t nthreads, const mav<T,2> &dirty, size_t verbosity)
bool do_wstacking, size_t nthreads, const mav<T,2> &dirty, size_t verbosity)
{
Baselines baselines(uvw, freq);
GridderConfig gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
x2dirty(gconf, makeMsServ(baselines,idx2,ms,wgt), dirty, verbosity);
x2dirty(gconf, makeMsServ(baselines,idx2,ms,wgt), dirty, do_wstacking, verbosity);
}
template<typename T> void degridding(const const_mav<double,2> &uvw,
const const_mav<double,1> &freq, const const_mav<T,2> &dirty,
const const_mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
{
Baselines baselines(uvw, freq);
GridderConfig gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
ms.fill(0);
dirty2x(gconf, dirty, makeMsServ(baselines,idx2,ms,wgt), verbosity);
}
template<typename T> void full_gridding(const const_mav<double,2> &uvw,
const const_mav<double,1> &freq, const const_mav<complex<T>,2> &ms,
const const_mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
size_t nthreads, const mav<T,2> &dirty, size_t verbosity)
{
Baselines baselines(uvw, freq);
GridderConfig gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
x2dirty_wstack2(gconf, makeMsServ(baselines,idx2,ms,wgt), dirty, verbosity);
}
template<typename T> void full_degridding(const const_mav<double,2> &uvw,
const const_mav<double,1> &freq, const const_mav<T,2> &dirty,
const const_mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
bool do_wstacking, size_t nthreads, const mav<complex<T>,2> &ms, size_t verbosity)
{
Baselines baselines(uvw, freq);
GridderConfig gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
ms.fill(0);
dirty2x_wstack(gconf, dirty, makeMsServ(baselines,idx2,ms,wgt), verbosity);
dirty2x(gconf, dirty, makeMsServ(baselines,idx2,ms,wgt), do_wstacking, verbosity);
}
} // namespace detail
......@@ -1666,8 +1603,8 @@ using detail::vis2grid_c;
using detail::grid2vis_c;
using detail::vis2dirty_wstack;
using detail::dirty2vis_wstack;
using detail::full_gridding;
using detail::full_degridding;
using detail::gridding;
using detail::degridding;
using detail::streamDump__;
using detail::CodeLocation;
} // namespace gridder
......
......@@ -36,10 +36,22 @@ auto None = py::none();
template<typename T>
using pyarr = py::array_t<T, 0>;
template<typename T> pyarr<T> getPyarr(const py::array &arr, const string &name)
{
auto t1=arr.dtype();
auto t2=pybind11::dtype::of<T>();
myassert(t1.is(t2),
"type mismatch for array '", name, "': expected '", t2.kind(),
t2.itemsize(), "', but got '", t1.kind(), t1.itemsize(), "'.");
return arr.cast<pyarr<T>>();
}
template<typename T> bool isPytype(const py::array &arr)
{ return arr.dtype().is(pybind11::dtype::of<T>()); }
template<typename T> pyarr<T> makeArray(const vector<size_t> &shape)
{ return pyarr<T>(shape); }
void checkArray(const py::array &arr, const char *aname,
void checkArray(const py::array &arr, const string &aname,
const vector<size_t> &shape)
{
if (size_t(arr.ndim())!=shape.size())
......@@ -58,7 +70,7 @@ void checkArray(const py::array &arr, const char *aname,
}
template<typename T> pyarr<T> provideArray(const py::object &in,
const vector<size_t> &shape)
const string &name, const vector<size_t> &shape)
{
if (in.is_none())
{
......@@ -69,23 +81,23 @@ template<typename T> pyarr<T> provideArray(const py::object &in,
tmp[i] = T(0);
return tmp_;
}
auto tmp_ = in.cast<pyarr<T>>();
checkArray(tmp_, "temporary", shape);
auto tmp_ = getPyarr<T>(in.cast<py::array>(), name);
checkArray(tmp_, name, shape);
return tmp_;
}
template<typename T> pyarr<T> providePotentialArray(const py::object &in,
const vector<size_t> &shape)
const string &name, const vector<size_t> &shape)
{
if (in.is_none())
return makeArray<T>(vector<size_t>(shape.size(), 0));
auto tmp_ = in.cast<pyarr<T>>();
checkArray(tmp_, "temporary", shape);
auto tmp_ = getPyarr<T>(in.cast<py::array>(), name);
checkArray(tmp_, name, shape);
return tmp_;
}
template<typename T> pyarr<T> provideCArray(py::object &in,
const vector<size_t> &shape)
const string &name, const vector<size_t> &shape)
{
if (in.is_none())
{
......@@ -96,8 +108,8 @@ template<typename T> pyarr<T> provideCArray(py::object &in,
tmp[i] = T(0);
return tmp_;
}
auto tmp_ = in.cast<pyarr<T>>();
checkArray(tmp_, "temporary", shape);
auto tmp_ = getPyarr<T>(in.cast<py::array>(), name);
checkArray(tmp_, name, shape);
return tmp_;
}
......@@ -212,7 +224,7 @@ class PyBaselines: public Baselines
{
auto vis=make_const_mav<1>(vis_);
auto idx=make_const_mav<1>(idx_);
auto res = provideArray<T>(ms_in, {nrows, nchan});
auto res = provideArray<T>(ms_in, "ms_in", {nrows, nchan});
auto ms = make_mav<2>(res);
{
py::gil_scoped_release release;
......@@ -411,9 +423,9 @@ template<typename T> pyarr<complex<T>> Pyvis2grid_c(
auto vis2 = make_const_mav<1>(vis_);
size_t nvis = vis2.shape(0);
auto idx2 = make_const_mav<1>(idx_);