Commit 538b4a1c authored by Martin Reinecke
Browse files

experimental templated version

parent 6db3283a
......@@ -61,22 +61,12 @@ template<typename T, size_t ndim> class mav
{ return operator()(i); }
const T &operator[](size_t i) const
{ return operator()(i); }
T &operator()(size_t i)
T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
const T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
T &operator()(size_t i, size_t j)
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
}
const T &operator()(size_t i, size_t j) const
T &operator()(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
......@@ -653,63 +643,84 @@ template<typename T, typename T2=complex<T>> class Helper
}
};
template<typename T> void vis2grid_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,1> &vis,
mav<complex<T>,2> &grid, const const_mav<T,1> &wgt)
// VisServ: adapter that exposes a 1-D visibility array (plus optional weights)
// through per-sample accessors: Nvis, getCoord, getVis, getIdx, setVis, addVis.
//   T  - floating-point type of the coordinates/visibilities
//   T2 - type of the visibility array (stored by value in `vis`)
//   T3 - type of the weight array (held by const reference in `wgt`)
// NOTE(review): this text is a rendered diff without +/- markers. The
// statements between the class members that reference gconf/grid/hlp look like
// removed lines of the old vis2grid_c body, not part of this class - confirm
// against the repository.
template<class T, class T2, class T3> class VisServ
{
size_t nvis = vis.shape(0);
checkShape(idx.shape(), {nvis});
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta();
size_t supp = gconf.Supp();
private:
// baselines, idx and wgt are references: the objects passed to the
// constructor must outlive this VisServ.
const Baselines<T> &baselines;
const const_mav<uint32_t,1> &idx;
T2 vis;
const T3 &wgt;
// Number of visibilities, taken from vis.shape(0) in the constructor.
size_t nvis;
// True when wgt.size()!=0; otherwise wgt is never indexed.
bool have_wgt;
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid.data());
T emb = exp(-2*beta);
int jump = hlp.lineJump();
const T * ku = hlp.kernel.data();
const T * kv = hlp.kernel.data()+supp;
public:
// Stores the arguments and validates shapes: idx must have shape {nvis},
// and wgt, if non-empty, must have shape {nvis} as well.
VisServ(const Baselines<T> &baselines_,
const const_mav<uint32_t,1> &idx_, T2 vis_, const T3 &wgt_)
: baselines(baselines_), idx(idx_), vis(vis_), wgt(wgt_)
{
nvis = vis.shape(0);
checkShape(idx.shape(), {nvis});
have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
}
// Number of visibilities served.
size_t Nvis() const { return nvis; }
// Coordinate of sample i, obtained from baselines.effectiveCoord(idx[i]).
UVW<T> getCoord(size_t i) const
{ return baselines.effectiveCoord(idx[i]); }
// Visibility i, multiplied by wgt(i) when weights are present.
complex<T> getVis(size_t i) const
{ return have_wgt ? vis(i)*wgt(i) : vis(i); }
// Raw index entry for sample i.
uint32_t getIdx(size_t i) const { return idx[i]; }
// Overwrites visibility i with v (times wgt(i) when weights are present).
// Declared const; the write goes through vis(i).
void setVis (size_t i, const complex<T> &v) const
{ vis(i) = have_wgt ? v*wgt(i) : v; }
// Accumulates v (times wgt(i) when weights are present) into visibility i.
void addVis (size_t i, const complex<T> &v) const
{ vis(i) += have_wgt ? v*wgt(i) : v; }
};
// MsServ: adapter that exposes a 2-D (row, channel) measurement set (plus
// optional weights) through per-sample accessors: Nvis, getCoord, getVis,
// getIdx, setVis, addVis. Sample i is mapped to its (row, chan) position via
// baselines.getRowChan(idx(i)).
//   T  - floating-point type of the coordinates/visibilities
//   T2 - type of the measurement-set array (stored by value in `ms`)
//   T3 - type of the weight array (held by const reference in `wgt`)
// NOTE(review): this text is a rendered diff without +/- markers. The loop
// statements between the class members that reference hlp/vis/emb/ku/kv look
// like removed lines of the old vis2grid_c body, not part of this class -
// confirm against the repository.
template<class T, class T2, class T3> class MsServ
{
private:
// baselines, idx and wgt are references: the objects passed to the
// constructor must outlive this MsServ.
const Baselines<T> &baselines;
const const_mav<uint32_t,1> &idx;
T2 ms;
const T3 &wgt;
// Number of samples, taken from idx.shape(0) in the constructor.
size_t nvis;
// True when wgt.size()!=0; otherwise wgt is never indexed.
bool have_wgt;
// Loop over sampling points
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW<T> coord = baselines.effectiveCoord(idx(ipart));
hlp.prep(coord.u, coord.v);
auto * ptr = hlp.p0w;
auto v(vis(ipart)*emb);
if (have_wgt)
v*=wgt(ipart);
for (size_t cu=0; cu<supp; ++cu)
public:
// Stores the arguments and validates shapes: ms must have shape
// {baselines.Nrows(), baselines.Nchannels()}, and wgt, if non-empty, must
// have that same shape.
MsServ(const Baselines<T> &baselines_,
const const_mav<uint32_t,1> &idx_, T2 ms_, const T3 &wgt_)
: baselines(baselines_), idx(idx_), ms(ms_), wgt(wgt_)
{
complex<T> tmp(v*ku[cu]);
for (size_t cv=0; cv<supp; ++cv)
ptr[cv] += tmp*kv[cv];
ptr+=jump;
auto nrows = baselines.Nrows();
auto nchan = baselines.Nchannels();
nvis = idx.shape(0);
checkShape(ms.shape(), {nrows, nchan});
have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nrows, nchan});
}
}
} // end of parallel region
}
// Number of samples served.
size_t Nvis() const { return nvis; }
// Coordinate of sample i, obtained from baselines.effectiveCoord(idx(i)).
UVW<T> getCoord(size_t i) const
{ return baselines.effectiveCoord(idx(i)); }
// Measurement-set entry for sample i, multiplied by the matching weight when
// weights are present.
complex<T> getVis(size_t i) const
{
auto rc = baselines.getRowChan(idx(i));
return have_wgt ? ms(rc.row, rc.chan)*wgt(rc.row, rc.chan)
: ms(rc.row, rc.chan);
}
// Raw index entry for sample i.
uint32_t getIdx(size_t i) const { return idx[i]; }
// Overwrites the entry for sample i with v (times the matching weight when
// weights are present). Declared const; the write goes through ms(row, chan).
void setVis (size_t i, const complex<T> &v) const
{
auto rc = baselines.getRowChan(idx(i));
ms(rc.row, rc.chan) = have_wgt ? v*wgt(rc.row, rc.chan) : v;
}
// Accumulates v (times the matching weight when weights are present) into the
// entry for sample i.
void addVis (size_t i, const complex<T> &v) const
{
auto rc = baselines.getRowChan(idx(i));
ms(rc.row, rc.chan) += have_wgt ? v*wgt(rc.row, rc.chan) : v;
}
};
template<typename T> void ms2grid_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &ms,
mav<complex<T>,2> &grid, const const_mav<T,2> &wgt)
template<typename T, typename Serv> void x2grid_c
(const GridderConfig<T> &gconf, const Serv &srv, mav<complex<T>,2> &grid)
{
auto nrows = baselines.Nrows();
auto nchan = baselines.Nchannels();
size_t nvis = idx.shape(0);
checkShape(ms.shape(), {nrows, nchan});
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nrows, nchan});
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
myassert(grid.contiguous(), "grid is not contiguous");
......@@ -723,18 +734,16 @@ template<typename T> void ms2grid_c
int jump = hlp.lineJump();
const T * ku = hlp.kernel.data();
const T * kv = hlp.kernel.data()+supp;
size_t np = srv.Nvis();
// Loop over sampling points
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
for (size_t ipart=0; ipart<np; ++ipart)
{
auto rc = baselines.getRowChan(idx(ipart));
UVW<T> coord = baselines.effectiveCoord(rc);
UVW<T> coord = srv.getCoord(ipart);
hlp.prep(coord.u, coord.v);
auto * ptr = hlp.p0w;
auto v(ms(rc.row,rc.chan)*emb);
if (have_wgt)
v*=wgt(rc.row,rc.chan);
auto v(srv.getVis(ipart)*emb);
for (size_t cu=0; cu<supp; ++cu)
{
complex<T> tmp(v*ku[cu]);
......@@ -746,16 +755,25 @@ template<typename T> void ms2grid_c
} // end of parallel region
}
template<typename T> void grid2vis_c
// Grids a 1-D visibility array onto `grid`: idx, vis and wgt are wrapped in a
// VisServ and the work is forwarded to x2grid_c.
// NOTE(review): this text is a rendered diff without +/- markers. The first
// pair of parameter lines (grid as input, vis as output) and the four shape
// checks ahead of the x2grid_c call look like removed lines of the previous
// version - confirm against the repository.
template<typename T> void vis2grid_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,1> &vis, const const_mav<T,1> &wgt)
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,1> &vis,
mav<complex<T>,2> &grid, const const_mav<T,1> &wgt)
{
size_t nvis = vis.shape(0);
checkShape(idx.shape(), {nvis});
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
// The VisServ temporary lives until the end of this full expression, i.e.
// for the whole duration of the x2grid_c call.
x2grid_c(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>> (baselines, idx, vis, wgt), grid);
}
template<typename T> void ms2grid_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &ms,
mav<complex<T>,2> &grid, const const_mav<T,2> &wgt)
{
x2grid_c(gconf, MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>> (baselines, idx, ms, wgt), grid);
}
template<typename T, typename Serv> void grid2x_c
(const GridderConfig<T> &gconf, const const_mav<complex<T>,2> &grid, const Serv &srv)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
myassert(grid.contiguous(), "grid is not contiguous");
......@@ -770,11 +788,12 @@ template<typename T> void grid2vis_c
int jump = hlp.lineJump();
const T * ku = hlp.kernel.data();
const T * kv = hlp.kernel.data()+supp;
size_t np = srv.Nvis();
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
for (size_t ipart=0; ipart<np; ++ipart)
{
UVW<T> coord = baselines.effectiveCoord(idx(ipart));
UVW<T> coord = srv.getCoord(ipart);
hlp.prep(coord.u, coord.v);
complex<T> r = 0;
const auto * ptr = hlp.p0r;
......@@ -786,59 +805,24 @@ template<typename T> void grid2vis_c
r += tmp*ku[cu];
ptr += jump;
}
if (have_wgt) r*=wgt[ipart];
vis[ipart] = r*emb;
srv.setVis(ipart, r*emb);
}
}
}
template<typename T> void grid2vis_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,1> &vis, const const_mav<T,1> &wgt)
{
grid2x_c(gconf, grid, VisServ<T,mav<complex<T>,1>,const_mav<T,1>> (baselines, idx, vis, wgt));
}
// Degrids `grid` into the 2-D (row, channel) measurement set `ms`: idx, ms
// and wgt are wrapped in an MsServ and the work is forwarded to grid2x_c
// (last statement of the body).
// NOTE(review): this text is a rendered diff without +/- markers. Everything
// in the body ahead of the final grid2x_c call (shape checks, the OpenMP
// region and the explicit degridding loop) looks like the removed previous
// implementation - confirm against the repository.
template<typename T> void grid2ms_c
(const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,2> &ms, const const_mav<T,2> &wgt)
{
auto nrows = baselines.Nrows();
auto nchan = baselines.Nchannels();
size_t nvis = idx.shape(0);
checkShape(ms.shape(), {nrows, nchan});
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nrows, nchan});
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
myassert(grid.contiguous(), "grid is not contiguous");
T beta = gconf.Beta();
size_t supp = gconf.Supp();
// Loop over sampling points
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid.data(), nullptr);
T emb = exp(-2*beta);
int jump = hlp.lineJump();
const T * ku = hlp.kernel.data();
const T * kv = hlp.kernel.data()+supp;
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
auto rc = baselines.getRowChan(idx(ipart));
UVW<T> coord = baselines.effectiveCoord(rc);
hlp.prep(coord.u, coord.v);
complex<T> r = 0;
const auto * ptr = hlp.p0r;
for (size_t cu=0; cu<supp; ++cu)
{
complex<T> tmp(0);
for (size_t cv=0; cv<supp; ++cv)
tmp += ptr[cv] * kv[cv];
r += tmp*ku[cu];
ptr += jump;
}
if (have_wgt) r*=wgt(rc.row,rc.chan);
ms(rc.row,rc.chan) = r*emb;
}
}
// The MsServ temporary lives until the end of this full expression, i.e.
// for the whole duration of the grid2x_c call.
grid2x_c(gconf, grid, MsServ<T,mav<complex<T>,2>,const_mav<T,2>> (baselines, idx, ms, wgt));
}
template<typename T> void apply_holo
......@@ -998,9 +982,8 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf
}
}
template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<complex<T>,2> &dirty)
template<typename T, typename Serv> void x2dirty_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const Serv &srv, mav<complex<T>,2> &dirty)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
......@@ -1008,8 +991,7 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
auto nv=gconf.Nv();
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
size_t nvis = vis.shape(0);
checkShape(idx.shape(),vis.shape());
size_t nvis = srv.Nvis();
checkShape(dirty.shape(),{nx_dirty,ny_dirty});
// determine w values for every visibility, and min/max w;
......@@ -1017,7 +999,7 @@ template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
vector<T> wval(nvis);
for (size_t ipart=0; ipart<nvis; ++ipart)
{
wval[ipart] = baselines.effectiveCoord(idx(ipart)).w;
wval[ipart] = srv.getCoord(ipart).w;
wmin = min(wmin,wval[ipart]);
wmax = max(wmax,wval[ipart]);
}
......@@ -1071,8 +1053,8 @@ cout << "working on w plane #" << iw << " containing " << nvis_plane[iw]
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+w_supp))
{
double x=min(1.,2./(w_supp*dw)*abs(wcur-wval[ipart]));
vis_loc[cnt] = vis[ipart]*kernel(x);
idx_loc[cnt] = idx[ipart];
vis_loc[cnt] = srv.getVis(ipart)*kernel(x);
idx_loc[cnt] = srv.getIdx(ipart);
++cnt;
}
myassert(cnt==nvis_plane[iw],"must not happen 2");
......@@ -1090,10 +1072,16 @@ cout << "applying correction for gridding in w direction" << endl;
apply_wcorr(gconf, dirty, kernel, dw);
}
template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
// Computes `dirty` from a 1-D visibility array by wrapping idx and vis in a
// VisServ (with an empty weight array) and forwarding to x2dirty_wstack.
// NOTE(review): this text is a rendered diff without +/- markers. The first
// of the two trailing parameter lines (dirty as input, vis as output) looks
// like a removed line of the previous dirty2vis_wstack signature - confirm
// against the repository.
template<typename T> void vis2dirty_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,2> &dirty, mav<complex<T>,1> &vis)
const const_mav<complex<T>,1> &vis, mav<complex<T>,2> &dirty)
{
// The empty weight array and the VisServ are both temporaries of this one
// full expression, so the reference VisServ keeps to the weights stays valid
// for the whole x2dirty_wstack call.
x2dirty_wstack(baselines, gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})), dirty);
}
template<typename T, typename Serv> void dirty2x_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<complex<T>,2> &dirty,
const Serv &srv)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
......@@ -1101,8 +1089,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
auto nv=gconf.Nv();
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
size_t nvis = vis.shape(0);
checkShape(idx.shape(),vis.shape());
size_t nvis = srv.Nvis();
checkShape(dirty.shape(),{nx_dirty,ny_dirty});
// determine w values for every visibility, and min/max w;
......@@ -1110,7 +1097,7 @@ template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
vector<T> wval(nvis);
for (size_t ipart=0; ipart<nvis; ++ipart)
{
wval[ipart] = baselines.effectiveCoord(idx(ipart)).w;
wval[ipart] = srv.getCoord(ipart).w;
wmin = min(wmin,wval[ipart]);
wmax = max(wmax,wval[ipart]);
}
......@@ -1140,7 +1127,7 @@ cout << "nplanes: " << nplanes << endl;
}
for (size_t i=0; i<nvis; ++i)
vis(i) = 0.;
srv.setVis(i, 0.);
tmpStorage<complex<T>,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
......@@ -1172,7 +1159,7 @@ cout << "working on w plane #" << iw << " containing " << nvis_plane[iw]
auto idx_loc = mav<uint32_t,1>(idx_loc_.data(), {nvis_plane[iw]});
for (size_t ipart=0; ipart<nvis; ++ipart)
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+w_supp))
idx_loc[cnt++] = idx[ipart];
idx_loc[cnt++] = srv.getIdx(ipart);
myassert(cnt==nvis_plane[iw],"must not happen 2");
grid2vis_c(baselines, gconf, const_mav<uint32_t,1>(idx_loc), const_mav<complex<T>,2>(grid), vis_loc, const_mav<T,1>(nullptr,{0}));
......@@ -1181,11 +1168,17 @@ myassert(cnt==nvis_plane[iw],"must not happen 2");
if ((int(iw)>=minplane[ipart]) && (iw<minplane[ipart]+w_supp))
{
double x=min(1.,2./(w_supp*dw)*abs(wcur-wval[ipart]));
vis[ipart] += vis_loc[cnt]*kernel(x);
srv.addVis(ipart, vis_loc[cnt]*kernel(x));
++cnt;
}
}
}
template<typename T> void dirty2vis_wstack(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,2> &dirty, mav<complex<T>,1> &vis)
{
dirty2x_wstack(baselines, gconf, dirty, VisServ<T,mav<complex<T>,1>,const_mav<T,1>>(baselines, idx, vis, const_mav<T,1>(nullptr,{0})));
}
} // namespace gridder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment