Commit 318cec49 authored by Martin Reinecke's avatar Martin Reinecke

slight restructuring

parent c6b28459
......@@ -907,6 +907,8 @@ template<class T, class T2> class VisServ
bool have_wgt;
public:
using Tsub = SubServ<T, VisServ>;
VisServ(const Baselines &baselines_,
const const_mav<idx_t,1> &idx_, T2 vis_, const const_mav<T,1> &wgt_)
: baselines(baselines_), idx(idx_), vis(vis_), wgt(wgt_),
......@@ -915,8 +917,8 @@ template<class T, class T2> class VisServ
checkShape(idx.shape(), {nvis});
if (have_wgt) checkShape(wgt.shape(), {nvis});
}
SubServ<T, VisServ> getSubserv(const const_mav<idx_t,1> &subidx) const
{ return SubServ<T, VisServ>(*this, subidx); }
Tsub getSubserv(const const_mav<idx_t,1> &subidx) const
{ return Tsub(*this, subidx); }
size_t Nvis() const { return nvis; }
const Baselines &getBaselines() const { return baselines; }
UVW getCoord(size_t i) const
......@@ -945,6 +947,8 @@ template<class T, class T2> class MsServ
bool have_wgt;
public:
using Tsub = SubServ<T, MsServ>;
MsServ(const Baselines &baselines_,
const const_mav<idx_t,1> &idx_, T2 ms_, const const_mav<T,2> &wgt_)
: baselines(baselines_), idx(idx_), ms(ms_), wgt(wgt_),
......@@ -953,8 +957,8 @@ template<class T, class T2> class MsServ
checkShape(ms.shape(), {baselines.Nrows(), baselines.Nchannels()});
if (have_wgt) checkShape(wgt.shape(), ms.shape());
}
SubServ<T, MsServ> getSubserv(const const_mav<idx_t,1> &subidx) const
{ return SubServ<T, MsServ>(*this, subidx); }
Tsub getSubserv(const const_mav<idx_t,1> &subidx) const
{ return Tsub(*this, subidx); }
size_t Nvis() const { return nvis; }
const Baselines &getBaselines() const { return baselines; }
UVW getCoord(size_t i) const
......@@ -1272,119 +1276,147 @@ template<typename T> void apply_wcorr(const GridderConfig &gconf,
}
}
template<typename Serv> void wminmax(const GridderConfig &gconf,
const Serv &srv, double &wmin, double &wmax)
template<typename Serv> class WgridHelper
{
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
private:
const Serv &srv;
double wmin, dw;
size_t nplanes, supp;
vector<vector<idx_t>> minplane;
size_t verbosity;
int curplane;
vector<idx_t> subidx;
static void wminmax(const GridderConfig &gconf,
const Serv &srv, double &wmin, double &wmax)
{
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
wmin= 1e38;
wmax=-1e38;
// FIXME maybe this can be done more intelligently
wmin= 1e38;
wmax=-1e38;
// FIXME maybe this can be done more intelligently
#pragma omp parallel for num_threads(nthreads) reduction(min:wmin) reduction(max:wmax)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
auto wval = abs(srv.getCoord(ipart).w);
wmin = min(wmin,wval);
wmax = max(wmax,wval);
}
}
for (size_t ipart=0; ipart<nvis; ++ipart)
{
auto wval = abs(srv.getCoord(ipart).w);
wmin = min(wmin,wval);
wmax = max(wmax,wval);
}
}
template<typename T> void update_idx(vector<T> &v, const vector<T> &add,
const vector<T> &del)
{
myassert(v.size()>=del.size(), "must not happen");
vector<T> res;
res.reserve((v.size()+add.size())-del.size());
auto iin=v.begin(), ein=v.end();
auto iadd=add.begin(), eadd=add.end();
auto irem=del.begin(), erem=del.end();
while(iin!=ein)
{
if ((irem!=erem) && (*iin==*irem))
{ ++irem; ++iin; } // skip removed entry
else if ((iadd!=eadd) && (*iadd<*iin))
res.push_back(*(iadd++)); // add new entry
else
res.push_back(*(iin++));
}
myassert(irem==erem, "must not happen");
while(iadd!=eadd)
res.push_back(*(iadd++));
myassert(res.size()==(v.size()+add.size())-del.size(), "must not happen");
v.swap(res);
}
template<typename T> static void update_idx(vector<T> &v, const vector<T> &add,
const vector<T> &del)
{
myassert(v.size()>=del.size(), "must not happen");
vector<T> res;
res.reserve((v.size()+add.size())-del.size());
auto iin=v.begin(), ein=v.end();
auto iadd=add.begin(), eadd=add.end();
auto irem=del.begin(), erem=del.end();
while(iin!=ein)
{
if ((irem!=erem) && (*iin==*irem))
{ ++irem; ++iin; } // skip removed entry
else if ((iadd!=eadd) && (*iadd<*iin))
res.push_back(*(iadd++)); // add new entry
else
res.push_back(*(iin++));
}
myassert(irem==erem, "must not happen");
while(iadd!=eadd)
res.push_back(*(iadd++));
myassert(res.size()==(v.size()+add.size())-del.size(), "must not happen");
v.swap(res);
}
template<typename Serv> void wstack_common(
const GridderConfig &gconf, const Serv &srv, double &wmin, double &dw,
size_t &nplanes, vector<vector<idx_t>> &minplane, size_t verbosity)
{
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
double wmax;
public:
WgridHelper(const GridderConfig &gconf, const Serv &srv_, size_t verbosity_)
: srv(srv_), verbosity(verbosity_), curplane(-1)
{
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
double wmax;
wminmax(gconf, srv, wmin, wmax);
wminmax(gconf, srv, wmin, wmax);
#ifdef _OPENMP
if (verbosity>0) cout << "Using " << nthreads << " threads" << endl;
if (verbosity>0) cout << "Using " << nthreads << " threads" << endl;
#else
if (verbosity>0) cout << "Using single-threaded mode" << endl;
if (verbosity>0) cout << "Using single-threaded mode" << endl;
#endif
if (verbosity>0) cout << "W range: " << wmin << " to " << wmax << endl;
double x0 = -0.5*gconf.Nxdirty()*gconf.Pixsize_x(),
y0 = -0.5*gconf.Nydirty()*gconf.Pixsize_y();
double nmin = sqrt(max(1.-x0*x0-y0*y0,0.))-1.;
dw = 0.25/abs(nmin);
nplanes = size_t((wmax-wmin)/dw+2);
dw = (1.+1e-13)*(wmax-wmin)/(nplanes-1);
auto supp = gconf.Supp();
wmin -= (0.5*supp-1)*dw;
wmax += (0.5*supp-1)*dw;
nplanes += supp-2;
if (verbosity>0) cout << "Kernel support: " << supp << endl;
if (verbosity>0) cout << "nplanes: " << nplanes << endl;
minplane.resize(nplanes);
vector<size_t> tcnt(my_max_threads()*nplanes,0);
if (verbosity>0) cout << "W range: " << wmin << " to " << wmax << endl;
double x0 = -0.5*gconf.Nxdirty()*gconf.Pixsize_x(),
y0 = -0.5*gconf.Nydirty()*gconf.Pixsize_y();
double nmin = sqrt(max(1.-x0*x0-y0*y0,0.))-1.;
dw = 0.25/abs(nmin);
nplanes = size_t((wmax-wmin)/dw+2);
dw = (1.+1e-13)*(wmax-wmin)/(nplanes-1);
supp = gconf.Supp();
wmin -= (0.5*supp-1)*dw;
wmax += (0.5*supp-1)*dw;
nplanes += supp-2;
if (verbosity>0) cout << "Kernel support: " << supp << endl;
if (verbosity>0) cout << "nplanes: " << nplanes << endl;
minplane.resize(nplanes);
vector<size_t> tcnt(my_max_threads()*nplanes,0);
#pragma omp parallel num_threads(nthreads)
{
vector<size_t> mytcnt(nplanes,0);
vector<size_t> nvp(nplanes,0);
auto mythread = my_thread_num();
vector<size_t> mytcnt(nplanes,0);
auto mythread = my_thread_num();
#pragma omp for schedule(static)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
int plane0 = max(0,int(1+(abs(srv.getCoord(ipart).w)-(0.5*supp*dw)-wmin)/dw));
++mytcnt[plane0];
for (int i=plane0; i<min<int>(nplanes,plane0+supp); ++i)
++nvp[i];
}
for (size_t ipart=0; ipart<nvis; ++ipart)
{
int plane0 = max(0,int(1+(abs(srv.getCoord(ipart).w)-(0.5*supp*dw)-wmin)/dw));
++mytcnt[plane0];
}
#pragma omp critical (wstack_common)
for (size_t i=0; i<nplanes; ++i)
tcnt[mythread*nplanes+i] = mytcnt[i];
for (size_t i=0; i<nplanes; ++i)
tcnt[mythread*nplanes+i] = mytcnt[i];
#pragma omp barrier
#pragma omp single
for (size_t j=0; j<nplanes; ++j)
{
size_t l=0;
for (int i=0; i<my_num_threads(); ++i)
l+=tcnt[i*nplanes+j];
minplane[j].resize(l);
}
for (size_t j=0; j<nplanes; ++j)
{
size_t l=0;
for (int i=0; i<my_num_threads(); ++i)
l+=tcnt[i*nplanes+j];
minplane[j].resize(l);
}
#pragma omp barrier
vector<size_t> myofs(nplanes, 0);
for (size_t j=0; j<nplanes; ++j)
for (int i=0; i<mythread; ++i)
myofs[j]+=tcnt[i*nplanes+j];
vector<size_t> myofs(nplanes, 0);
for (size_t j=0; j<nplanes; ++j)
for (int i=0; i<mythread; ++i)
myofs[j]+=tcnt[i*nplanes+j];
#pragma omp for schedule(static)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
int plane0 = max(0,int(1+(abs(srv.getCoord(ipart).w)-(0.5*supp*dw)-wmin)/dw));
minplane[plane0][myofs[plane0]++]=idx_t(ipart);
}
for (size_t ipart=0; ipart<nvis; ++ipart)
{
int plane0 = max(0,int(1+(abs(srv.getCoord(ipart).w)-(0.5*supp*dw)-wmin)/dw));
minplane[plane0][myofs[plane0]++]=idx_t(ipart);
}
}
}
}
typename Serv::Tsub getSubserv() const
{
auto subidx2 = const_mav<idx_t, 1>(subidx.data(), {subidx.size()});
return srv.getSubserv(subidx2);
}
double W() const { return wmin+curplane*dw; }
size_t Nvis() const { return subidx.size(); }
double DW() const { return dw; }
bool advance()
{
if (++curplane>=int(nplanes)) return false;
update_idx(subidx, minplane[curplane], curplane>=int(supp) ? minplane[curplane-supp] : vector<idx_t>());
if (verbosity>1)
cout << "Working on plane " << curplane << " containing " << subidx.size()
<< " visibilities" << endl;
return true;
}
};
template<typename T, typename Serv> void x2dirty(
const GridderConfig &gconf, const Serv &srv, const mav<T,2> &dirty,
......@@ -1392,44 +1424,29 @@ template<typename T, typename Serv> void x2dirty(
{
if (do_wstacking)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
auto supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<vector<idx_t>> minplane;
if (verbosity>0) cout << "Gridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, minplane, verbosity);
WgridHelper<Serv> hlp(gconf, srv, verbosity);
double dw = hlp.DW();
dirty.fill(0);
vector<idx_t> subidx;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty_(dirty.shape());
auto tdirty=tdirty_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
while(hlp.advance()) // iterate over w planes
{
update_idx(subidx, minplane[iw], iw>=supp ? minplane[iw-supp] : vector<idx_t>());
if (subidx.size()==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << subidx.size()
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
auto subidx2 = const_mav<idx_t, 1>(subidx.data(), {subidx.size()});
auto subsrv = srv.getSubserv(subidx2);
if (hlp.Nvis()==0) continue;
grid.fill(0);
x2grid_c(gconf, subsrv, grid, wcur, dw);
x2grid_c(gconf, hlp.getSubserv(), grid, hlp.W(), dw);
gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, true);
gconf.apply_wscreen(cmav(tdirty), tdirty, hlp.W(), true);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
for (size_t i=0; i<gconf.Nxdirty(); ++i)
for (size_t j=0; j<gconf.Nydirty(); ++j)
dirty(i,j) += tdirty(i,j).real();
}
// correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel(supp, nthreads), dw);
apply_wcorr(gconf, dirty, ES_Kernel(gconf.Supp(), nthreads), dw);
}
else
{
......@@ -1463,49 +1480,34 @@ template<typename T, typename Serv> void dirty2x(
{
if (do_wstacking)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
auto supp = gconf.Supp();
size_t nx_dirty=gconf.Nxdirty(), ny_dirty=gconf.Nydirty();
size_t nthreads = gconf.Nthreads();
double wmin;
double dw;
size_t nplanes;
vector<vector<idx_t>> minplane;
if (verbosity>0) cout << "Degridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, minplane, verbosity);
WgridHelper<Serv> hlp(gconf, srv, verbosity);
double dw = hlp.DW();
tmpStorage<T,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty(i,j) = dirty(i,j);
// correct for w gridding
apply_wcorr(gconf, tdirty, ES_Kernel(supp, nthreads), dw);
vector<idx_t> subidx;
apply_wcorr(gconf, tdirty, ES_Kernel(gconf.Supp(), nthreads), dw);
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty2_({nx_dirty,ny_dirty});
tmpStorage<complex<T>,2> tdirty2_({nx_dirty, ny_dirty});
auto tdirty2=tdirty2_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
while(hlp.advance()) // iterate over w planes
{
update_idx(subidx, minplane[iw], iw>=supp ? minplane[iw-supp] : vector<idx_t>());
if (subidx.size()==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << subidx.size()
<< " visibilities" << endl;
double wcur = wmin+iw*dw;
if (hlp.Nvis()==0) continue;
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty2(i,j) = tdirty(i,j);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, false);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, hlp.W(), false);
gconf.dirty2grid_c(cmav(tdirty2), grid);
auto subidx2 = const_mav<idx_t, 1>(subidx.data(), {subidx.size()});
auto subsrv = srv.getSubserv(subidx2);
grid2x_c(gconf, cmav(grid), subsrv, wcur, dw);
grid2x_c(gconf, cmav(grid), hlp.getSubserv(), hlp.W(), dw);
}
}
else
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment