Commit 8f4dbdea authored by Martin Reinecke
Browse files

memory reductions; faster irow, ichan

parent 098e6387
...@@ -334,6 +334,7 @@ template<typename T> class Baselines ...@@ -334,6 +334,7 @@ template<typename T> class Baselines
vector<UVW<T>> coord; vector<UVW<T>> coord;
vector<T> f_over_c; vector<T> f_over_c;
size_t nrows, nchan; size_t nrows, nchan;
size_t shift, mask;
public: public:
Baselines(const const_mav<T,2> &coord_, const const_mav<T,1> &freq) Baselines(const const_mav<T,2> &coord_, const const_mav<T,1> &freq)
...@@ -342,6 +343,9 @@ template<typename T> class Baselines ...@@ -342,6 +343,9 @@ template<typename T> class Baselines
myassert(coord_.shape(1)==3, "dimension mismatch"); myassert(coord_.shape(1)==3, "dimension mismatch");
nrows = coord_.shape(0); nrows = coord_.shape(0);
nchan = freq.shape(0); nchan = freq.shape(0);
shift=0;
while((size_t(1)<<shift)<nchan) ++shift;
mask=(size_t(1)<<shift)-1;
myassert(nrows*nchan<(size_t(1)<<32), "too many entries in MS"); myassert(nrows*nchan<(size_t(1)<<32), "too many entries in MS");
f_over_c.resize(nchan); f_over_c.resize(nchan);
for (size_t i=0; i<nchan; ++i) for (size_t i=0; i<nchan; ++i)
...@@ -355,11 +359,7 @@ template<typename T> class Baselines ...@@ -355,11 +359,7 @@ template<typename T> class Baselines
} }
RowChan getRowChan(uint32_t index) const RowChan getRowChan(uint32_t index) const
{ { return RowChan{index>>shift, index&mask}; }
size_t irow = index/nchan;
size_t ichan = index-nchan*irow;
return RowChan{irow, ichan};
}
UVW<T> effectiveCoord(const RowChan &rc) const UVW<T> effectiveCoord(const RowChan &rc) const
{ return coord[rc.row]*f_over_c[rc.chan]; } { return coord[rc.row]*f_over_c[rc.chan]; }
...@@ -367,6 +367,8 @@ template<typename T> class Baselines ...@@ -367,6 +367,8 @@ template<typename T> class Baselines
{ return effectiveCoord(getRowChan(index)); } { return effectiveCoord(getRowChan(index)); }
size_t Nrows() const { return nrows; } size_t Nrows() const { return nrows; }
size_t Nchannels() const { return nchan; } size_t Nchannels() const { return nchan; }
uint32_t getIdx(size_t irow, size_t ichan) const
{ return ichan+(irow<<shift); }
void effectiveUVW(const mav<const uint32_t,1> &idx, mav<T,2> &res) const void effectiveUVW(const mav<const uint32_t,1> &idx, mav<T,2> &res) const
{ {
...@@ -703,8 +705,6 @@ template<class T, class T2, class T3> class VisServ ...@@ -703,8 +705,6 @@ template<class T, class T2, class T3> class VisServ
have_wgt = wgt.size()!=0; have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis}); if (have_wgt) checkShape(wgt.shape(), {nvis});
} }
VisServ(const VisServ &orig, const const_mav<uint32_t,1> &newidx)
: VisServ(orig.baselines, newidx, orig.vis, orig.wgt) {}
size_t Nvis() const { return nvis; } size_t Nvis() const { return nvis; }
const Baselines<T> &getBaselines() const { return baselines; } const Baselines<T> &getBaselines() const { return baselines; }
UVW<T> getCoord(size_t i) const UVW<T> getCoord(size_t i) const
...@@ -1107,14 +1107,13 @@ template<typename T, typename Serv> void wstack_common( ...@@ -1107,14 +1107,13 @@ template<typename T, typename Serv> void wstack_common(
// determine w values for every visibility, and min/max w; // determine w values for every visibility, and min/max w;
wmin=T(1e38); wmin=T(1e38);
T wmax=T(-1e38); T wmax=T(-1e38);
vector<T> wval(nvis);
#pragma omp parallel for num_threads(nthreads) reduction(min:wmin) reduction(max:wmax) #pragma omp parallel for num_threads(nthreads) reduction(min:wmin) reduction(max:wmax)
for (size_t ipart=0; ipart<nvis; ++ipart) for (size_t ipart=0; ipart<nvis; ++ipart)
{ {
wval[ipart] = abs(srv.getCoord(ipart).w); auto wval = abs(srv.getCoord(ipart).w);
wmin = min(wmin,wval[ipart]); wmin = min(wmin,wval);
wmax = max(wmax,wval[ipart]); wmax = max(wmax,wval);
} }
if (verbosity>0) cout << "Using " << nthreads << " threads" << endl; if (verbosity>0) cout << "Using " << nthreads << " threads" << endl;
if (verbosity>0) cout << "W range: " << wmin << " to " << wmax << endl; if (verbosity>0) cout << "W range: " << wmin << " to " << wmax << endl;
...@@ -1139,7 +1138,7 @@ template<typename T, typename Serv> void wstack_common( ...@@ -1139,7 +1138,7 @@ template<typename T, typename Serv> void wstack_common(
#pragma omp for #pragma omp for
for (size_t ipart=0; ipart<nvis; ++ipart) for (size_t ipart=0; ipart<nvis; ++ipart)
{ {
int plane0 = int((wval[ipart]-wmin)/dw+0.5*w_supp)-int(w_supp-1); int plane0 = int((abs(srv.getCoord(ipart).w)-wmin)/dw+0.5*w_supp)-int(w_supp-1);
minplane[ipart]=plane0; minplane[ipart]=plane0;
for (int i=max<int>(0,plane0); i<min<int>(nplanes,plane0+w_supp); ++i) for (int i=max<int>(0,plane0); i<min<int>(nplanes,plane0+w_supp); ++i)
++nvploc[i]; ++nvploc[i];
...@@ -1219,6 +1218,63 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines, ...@@ -1219,6 +1218,63 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
{ {
x2dirty_wstack(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty); x2dirty_wstack(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty);
} }
template<typename T, typename Serv> void x2dirty_wstack2(
const GridderConfig<T> &gconf, const Serv &srv, const mav<T,2> &dirty, size_t verbosity=0)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
auto nu=gconf.Nu();
auto nv=gconf.Nv();
size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads);
T wmin;
double dw;
size_t nplanes;
vector<size_t> nvis_plane;
vector<int> minplane;
if (verbosity>0) cout << "Gridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
dirty.fill(0);
vector<uint32_t> newidx;
tmpStorage<complex<T>,2> grid_({nu,nv});
auto grid=grid_.getMav();
tmpStorage<complex<T>,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
for (size_t iw=0; iw<nplanes; ++iw)
{
if (nvis_plane[iw]==0) continue;
if (verbosity>1)
cout << "Working on plane " << iw << " containing " << nvis_plane[iw]
<< " visibilities" << endl;
T wcur = wmin+iw*dw;
size_t cnt=0;
newidx.resize(nvis_plane[iw]);
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+w_supp))
newidx[cnt++] = srv.getIdx(ipart);
grid.fill(0);
Serv newsrv(srv, const_mav<uint32_t, 1>(newidx.data(), {newidx.size()}));
x2grid_c(gconf, newsrv, grid, wcur, T(dw));
gconf.grid2dirty_c(grid, tdirty);
gconf.apply_wscreen(tdirty, tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
dirty(i,j) += tdirty(i,j).real();
}
// correct for w gridding
apply_wcorr(gconf, dirty, kernel, dw);
}
template<typename T> void vis2dirty_wstack2(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<T,2> &dirty)
{
x2dirty_wstack2(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty);
}
template<typename T, typename Serv> void dirty2x_wstack( template<typename T, typename Serv> void dirty2x_wstack(
const GridderConfig<T> &gconf, const const_mav<T,2> &dirty, const GridderConfig<T> &gconf, const const_mav<T,2> &dirty,
...@@ -1367,7 +1423,7 @@ template<typename T> void fillIdx(const Baselines<T> &baselines, ...@@ -1367,7 +1423,7 @@ template<typename T> void fillIdx(const Baselines<T> &baselines,
{ {
auto w = abs(baselines.effectiveCoord(RowChan{irow,size_t(ichan)}).w); auto w = abs(baselines.effectiveCoord(RowChan{irow,size_t(ichan)}).w);
if ((w>=wmin) && (w<wmax)) if ((w>=wmin) && (w<wmax))
res[acc[tmp[idx++]]++] = irow*nchan+ichan; res[acc[tmp[idx++]]++] = baselines.getIdx(irow, ichan);
} }
} }
...@@ -1407,7 +1463,7 @@ template<typename T> vector<uint32_t> getWgtIndices(const Baselines<T> &baseline ...@@ -1407,7 +1463,7 @@ template<typename T> vector<uint32_t> getWgtIndices(const Baselines<T> &baseline
for (size_t irow=0, idx=0; irow<nrow; ++irow) for (size_t irow=0, idx=0; irow<nrow; ++irow)
for (size_t ichan=0; ichan<nchan; ++ichan) for (size_t ichan=0; ichan<nchan; ++ichan)
if ((!have_wgt) || (wgt(irow,ichan)!=0)) if ((!have_wgt) || (wgt(irow,ichan)!=0))
res[acc[tmp[idx++]]++] = irow*nchan+ichan; res[acc[tmp[idx++]]++] = baselines.getIdx(irow, ichan);
return res; return res;
} }
...@@ -1448,7 +1504,7 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw, ...@@ -1448,7 +1504,7 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
auto idx = getWgtIndices(baselines, gconf, wgt); auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()}); auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt); auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
x2dirty_wstack(gconf, srv, dirty, verbosity); x2dirty_wstack2(gconf, srv, dirty, verbosity);
} }
template<typename T> void full_degridding(const const_mav<T,2> &uvw, template<typename T> void full_degridding(const const_mav<T,2> &uvw,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment