Commit 6575e568 authored by Martin Reinecke
Browse files

polishing

parent bdac5ac6
......@@ -227,15 +227,8 @@ template<size_t ndim> void checkShape
/*! Computes \a v1 modulo \a v2, mapped into the half-open range [0, \a v2).
    \a v1 may have either sign; \a v2 is required to be positive. */
template<typename T> inline T fmodulo (T v1, T v2)
  {
  if (!(v1>=0))
    {
    // A single shift by v2 suffices when v1 lies in [-v2, 0); anything
    // further below zero goes through fmod first.
    T shifted = v1+v2;
    if (shifted<0)
      shifted = fmod(v1,v2)+v2;
    // Rounding (or an exact multiple of v2) can land on v2 itself; fold that
    // case back to zero so the result stays strictly below v2.
    if (shifted==v2)
      return T(0);
    return shifted;
    }
  // Non-negative input: values already below v2 are returned untouched.
  if (v1<v2)
    return v1;
  return fmod(v1,v2);
  }
/*! Returns \a v minus the largest integer not exceeding it, i.e.
    \a v - floor(\a v). Mathematically the result lies in [0, 1); note that
    for negative \a v of very small magnitude the subtraction can round to
    exactly 1. */
template<typename T> inline T fmod1 (T v)
  {
  const auto whole = floor(v);
  return v-whole;
  }
//
// Utilities for Gauss-Legendre quadrature
......@@ -544,10 +537,10 @@ template<typename T> class GridderConfig
{
protected:
size_t nx_dirty, ny_dirty;
T eps, psx, psy;
double eps, psx, psy;
size_t supp, nsafe, nu, nv;
T beta;
vector<T> cfu, cfv;
double beta;
vector<double> cfu, cfv;
size_t nthreads;
complex<T> wscreen(double x, double y, double w, bool adjoint) const
......@@ -598,7 +591,7 @@ template<typename T> class GridderConfig
// Returns nv, the extent of the second grid axis (callers check grid shapes
// against {Nu(), Nv()}).
size_t Nv() const { return nv; }
// Returns supp, the per-axis kernel length used as the bound of the
// gridding/degridding inner loops.
size_t Supp() const { return supp; }
// Returns nsafe. NOTE(review): looks like a safety margin around the grid
// (getpix caps start indices at (nu+nsafe)-supp) - confirm.
size_t Nsafe() const { return nsafe; }
T Beta() const { return beta; }
double Beta() const { return beta; }
// Returns nthreads, the value handed to num_threads() in the OpenMP
// parallel regions of the gridding routines.
size_t Nthreads() const { return nthreads; }
void apply_taper(const mav<const T,2> &img, mav<T,2> &img2, bool divide) const
......@@ -629,7 +622,7 @@ template<typename T> class GridderConfig
if (i2>=nu) i2-=nu;
size_t j2 = nv-ny_dirty/2+j;
if (j2>=nv) j2-=nv;
dirty(i,j) = tmav(i2,j2)*cfu[i]*cfv[j];
dirty(i,j) = tmav(i2,j2)*T(cfu[i]*cfv[j]);
}
}
......@@ -650,7 +643,7 @@ template<typename T> class GridderConfig
if (i2>=nu) i2-=nu;
size_t j2 = nv-ny_dirty/2+j;
if (j2>=nv) j2-=nv;
dirty(i,j) = tmp(i2,j2)*cfu[i]*cfv[j];
dirty(i,j) = tmp(i2,j2)*T(cfu[i]*cfv[j]);
}
}
......@@ -666,7 +659,7 @@ template<typename T> class GridderConfig
if (i2>=nu) i2-=nu;
size_t j2 = nv-ny_dirty/2+j;
if (j2>=nv) j2-=nv;
grid(i2,j2) = dirty(i,j)*cfu[i]*cfv[j];
grid(i2,j2) = dirty(i,j)*T(cfu[i]*cfv[j]);
}
hartley2_2D<T>(cmav(grid), grid, nthreads);
}
......@@ -684,7 +677,7 @@ template<typename T> class GridderConfig
if (i2>=nu) i2-=nu;
size_t j2 = nv-ny_dirty/2+j;
if (j2>=nv) j2-=nv;
grid(i2,j2) = dirty(i,j)*cfu[i]*cfv[j];
grid(i2,j2) = dirty(i,j)*T(cfu[i]*cfv[j]);
}
constexpr auto sc = ptrdiff_t(sizeof(complex<T>));
pocketfft::stride_t strides{grid.stride(0)*sc,grid.stride(1)*sc};
......@@ -694,10 +687,10 @@ template<typename T> class GridderConfig
/*! Converts the coordinates (\a u_in, \a v_in) into grid positions.
    \a u and \a v receive the coordinate scaled by psx/psy, reduced to its
    fractional part with fmod1() and multiplied by nu/nv.
    \a iu0 and \a iv0 receive the associated start index, presumably the
    lowest grid index of the supp-wide kernel footprint; it is capped at
    (nu+nsafe)-supp and (nv+nsafe)-supp respectively. */
void getpix(T u_in, T v_in, T &u, T &v, int &iu0, int &iv0) const
  {
  // Only fmod1() is used for the wrap-around: the former fmodulo() calls were
  // dead stores, and with psx/psy stored as double their mixed argument types
  // no longer deduce a single T.
  u = fmod1(u_in*psx)*nu;
  // Adding nu before the int conversion and subtracting it afterwards keeps
  // the truncated value non-negative (presumably to emulate floor()).
  iu0 = int(u-supp*0.5 + 1 + nu) - nu;
  iu0 = min<int>(iu0, (nu+nsafe)-supp);
  v = fmod1(v_in*psy)*nv;
  iv0 = int(v-supp*0.5 + 1 + nv) - nv;
  iv0 = min<int>(iv0, (nv+nsafe)-supp);
  }
......@@ -856,11 +849,10 @@ template<class T, class T2, class T3> class VisServ
public:
VisServ(const Baselines<T> &baselines_,
const const_mav<uint32_t,1> &idx_, T2 vis_, const T3 &wgt_)
: baselines(baselines_), idx(idx_), vis(vis_), wgt(wgt_)
: baselines(baselines_), idx(idx_), vis(vis_), wgt(wgt_),
nvis(vis.shape(0)), have_wgt(wgt.size()!=0)
{
nvis = vis.shape(0);
checkShape(idx.shape(), {nvis});
have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
}
size_t Nvis() const { return nvis; }
......@@ -875,6 +867,10 @@ template<class T, class T2, class T3> class VisServ
void addVis (size_t i, const complex<T> &v) const
{ vis(i) += have_wgt ? v*wgt(i) : v; }
};
/*! Convenience factory: builds a VisServ whose template arguments are deduced
    from the supplied arguments, so callers need not spell them out. */
template<class T, class T2, class T3> VisServ<T,T2,T3> makeVisServ
  (const Baselines<T> &baselines,
   const const_mav<uint32_t,1> &idx, T2 vis, const T3 &wgt)
  {
  using Serv = VisServ<T,T2,T3>;
  return Serv(baselines, idx, vis, wgt);
  }
template<class T, class T2, class T3> class MsServ
{
private:
......@@ -888,14 +884,11 @@ template<class T, class T2, class T3> class MsServ
public:
MsServ(const Baselines<T> &baselines_,
const const_mav<uint32_t,1> &idx_, T2 ms_, const T3 &wgt_)
: baselines(baselines_), idx(idx_), ms(ms_), wgt(wgt_)
: baselines(baselines_), idx(idx_), ms(ms_), wgt(wgt_),
nvis(idx.shape(0)), have_wgt(wgt.size()!=0)
{
auto nrows = baselines.Nrows();
auto nchan = baselines.Nchannels();
nvis = idx.shape(0);
checkShape(ms.shape(), {nrows, nchan});
have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nrows, nchan});
checkShape(ms.shape(), {baselines.Nrows(), baselines.Nchannels()});
if (have_wgt) checkShape(wgt.shape(), ms.shape());
}
// Builds a new server that reuses orig's baselines, ms and wgt but works on
// the index array newidx; delegates to the main constructor, so its shape
// checks run again.
MsServ(const MsServ &orig, const const_mav<uint32_t,1> &newidx)
: MsServ(orig.baselines, newidx, orig.ms, orig.wgt) {}
......@@ -921,16 +914,20 @@ template<class T, class T2, class T3> class MsServ
ms(rc.row, rc.chan) += have_wgt ? v*wgt(rc.row, rc.chan) : v;
}
};
/*! Convenience factory: builds an MsServ whose template arguments are deduced
    from the supplied arguments, so callers need not spell them out. */
template<class T, class T2, class T3> MsServ<T,T2,T3> makeMsServ
  (const Baselines<T> &baselines,
   const const_mav<uint32_t,1> &idx, T2 ms, const T3 &wgt)
  {
  using Serv = MsServ<T,T2,T3>;
  return Serv(baselines, idx, ms, wgt);
  }
template<typename T, typename Serv> void x2grid_c
(const GridderConfig<T> &gconf, const Serv &srv, mav<complex<T>,2> &grid, T w0=-1, T dw=-1)
(const GridderConfig<T> &gconf, const Serv &srv, mav<complex<T>,2> &grid,
T w0=-1, T dw=-1)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
checkShape(grid.shape(), {gconf.Nu(), gconf.Nv()});
myassert(grid.contiguous(), "grid is not contiguous");
size_t supp = gconf.Supp();
size_t nthreads = gconf.Nthreads();
bool do_w_gridding=dw>0;
bool do_w_gridding = dw>0;
#pragma omp parallel num_threads(nthreads)
{
......@@ -975,7 +972,7 @@ template<typename T> void vis2grid_c
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,1> &vis,
mav<complex<T>,2> &grid, const const_mav<T,1> &wgt, T w0=-1, T dw=-1)
{
x2grid_c(gconf, VisServ<T,const_mav<complex<T>,1>,const_mav<T,1>> (baselines, idx, vis, wgt), grid, w0, dw);
x2grid_c(gconf, makeVisServ(baselines, idx, vis, wgt), grid, w0, dw);
}
template<typename T> void ms2grid_c
......@@ -983,7 +980,7 @@ template<typename T> void ms2grid_c
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &ms,
mav<complex<T>,2> &grid, const const_mav<T,2> &wgt)
{
x2grid_c(gconf, MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>> (baselines, idx, ms, wgt), grid);
x2grid_c(gconf, makeMsServ(baselines, idx, ms, wgt), grid);
}
template<typename T, typename Serv> void grid2x_c
......@@ -1039,7 +1036,7 @@ template<typename T> void grid2vis_c
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,1> &vis, const const_mav<T,1> &wgt, T w0=-1, T dw=-1)
{
grid2x_c(gconf, grid, VisServ<T,mav<complex<T>,1>,const_mav<T,1>> (baselines, idx, vis, wgt), w0, dw);
grid2x_c(gconf, grid, makeVisServ(baselines, idx, vis, wgt), w0, dw);
}
template<typename T> void grid2ms_c
......@@ -1047,7 +1044,7 @@ template<typename T> void grid2ms_c
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,2> &ms, const const_mav<T,2> &wgt)
{
grid2x_c(gconf, grid, MsServ<T,mav<complex<T>,2>,const_mav<T,2>> (baselines, idx, ms, wgt));
grid2x_c(gconf, grid, makeMsServ(baselines, idx, ms, wgt));
}
template<typename T> void apply_holo
......@@ -1055,16 +1052,13 @@ template<typename T> void apply_holo
const const_mav<uint32_t,1> &idx, const const_mav<complex<T>,2> &grid,
mav<complex<T>,2> &ogrid, const const_mav<T,1> &wgt)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
checkShape(grid.shape(), {nu, nv});
checkShape(grid.shape(), {gconf.Nu(), gconf.Nv()});
size_t nvis = idx.shape(0);
size_t nthreads = gconf.Nthreads();
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
checkShape(ogrid.shape(), {nu, nv});
checkShape(ogrid.shape(), ogrid.shape());
ogrid.fill(0);
size_t supp = gconf.Supp();
// Loop over sampling points
......@@ -1072,8 +1066,8 @@ template<typename T> void apply_holo
{
Helper<T> hlp(gconf, grid.data(), ogrid.data());
int jump = hlp.lineJump();
const T * ku = hlp.kernel;
const T * kv = hlp.kernel+supp;
const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp;
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
......@@ -1082,11 +1076,17 @@ template<typename T> void apply_holo
coord.FixW();
hlp.prep(coord);
complex<T> r = 0;
const auto * ptr = hlp.p0r;
const auto * RESTRICT ptr = hlp.p0r;
for (size_t cu=0; cu<supp; ++cu)
{
complex<T> tmp(0);
for (size_t cv=0; cv<supp; ++cv)
size_t cv=0;
for (;cv<supp-3; cv+=4)
tmp += ptr[cv ]*kv[cv ]
+ptr[cv+1]*kv[cv+1]
+ptr[cv+2]*kv[cv+2]
+ptr[cv+3]*kv[cv+3];
for (; cv<supp; ++cv)
tmp += ptr[cv] * kv[cv];
r += tmp*ku[cu];
ptr += jump;
......@@ -1096,11 +1096,19 @@ template<typename T> void apply_holo
auto twgt = wgt(ipart);
r*=twgt*twgt;
}
auto * wptr = hlp.p0w;
auto * RESTRICT wptr = hlp.p0w;
for (size_t cu=0; cu<supp; ++cu)
{
complex<T> tmp(r*ku[cu]);
for (size_t cv=0; cv<supp; ++cv)
size_t cv=0;
for (;cv<supp-3; cv+=4)
{
wptr[cv ] += tmp*kv[cv ];
wptr[cv+1] += tmp*kv[cv+1];
wptr[cv+2] += tmp*kv[cv+2];
wptr[cv+3] += tmp*kv[cv+3];
}
for (; cv<supp; ++cv)
wptr[cv] += tmp*kv[cv];
wptr += jump;
}
......@@ -1113,16 +1121,14 @@ template<typename T> void get_correlations
const const_mav<uint32_t,1> &idx, int du, int dv,
mav<T,2> &ogrid, const const_mav<T,1> &wgt)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
size_t nvis = idx.shape(0);
bool have_wgt = wgt.size()!=0;
if (have_wgt) checkShape(wgt.shape(), {nvis});
checkShape(ogrid.shape(), {nu, nv});
checkShape(ogrid.shape(), {gconf.Nu(), gconf.Nv()});
size_t supp = gconf.Supp();
myassert(size_t(abs(du))<supp, "|du| must be smaller than Supp");
myassert(size_t(abs(dv))<supp, "|dv| must be smaller than Supp");
size_t nthreads = gconf.Nthreads();
ogrid.fill(0);
size_t u0, u1, v0, v1;
......@@ -1140,8 +1146,8 @@ template<typename T> void get_correlations
{
Helper<T,T> hlp(gconf, nullptr, ogrid.data());
int jump = hlp.lineJump();
const T * ku = hlp.kernel;
const T * kv = hlp.kernel+supp;
const T * RESTRICT ku = hlp.kernel;
const T * RESTRICT kv = hlp.kernel+supp;
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
......@@ -1149,7 +1155,7 @@ template<typename T> void get_correlations
UVW<T> coord = baselines.effectiveCoord(idx(ipart));
coord.FixW();
hlp.prep(coord);
auto * wptr = hlp.p0w + u0*jump;
auto * RESTRICT wptr = hlp.p0w + u0*jump;
auto f0 = T(1);
if (have_wgt)
{
......@@ -1211,19 +1217,16 @@ template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf
template<typename T, typename Serv> void x2dirty(
const GridderConfig<T> &gconf, const Serv &srv, const mav<T,2> &dirty, size_t verbosity=0)
{
auto nu=gconf.Nu();
auto nv=gconf.Nv();
if (verbosity>0)
cout << "Gridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
tmpStorage<complex<T>,2> grid_({nu,nv});
tmpStorage<complex<T>,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
grid_.fill(0.);
x2grid_c(gconf, srv, grid);
tmpStorage<T,2> rgrid_({nu, nv});
tmpStorage<T,2> rgrid_(grid.shape());
auto rgrid=rgrid_.getMav();
complex2hartley(cmav(grid), rgrid, gconf.Nthreads());
gconf.grid2dirty(cmav(rgrid), dirty);
......@@ -1233,22 +1236,15 @@ template<typename T, typename Serv> void dirty2x(
const GridderConfig<T> &gconf, const const_mav<T,2> &dirty,
const Serv &srv, size_t verbosity=0)
{
auto nu=gconf.Nu();
auto nv=gconf.Nv();
size_t nvis = srv.Nvis();
if (verbosity>0)
cout << "Degridding without w-stacking: " << srv.Nvis()
<< " visibilities" << endl;
if (verbosity>0) cout << "Using " << gconf.Nthreads() << " threads" << endl;
for (size_t i=0; i<nvis; ++i)
srv.setVis(i, 0.);
tmpStorage<T,2> grid_({nu,nv});
tmpStorage<T,2> grid_({gconf.Nu(), gconf.Nv()});
auto grid=grid_.getMav();
gconf.dirty2grid(dirty, grid);
tmpStorage<complex<T>,2> grid2_({nu,nv});
tmpStorage<complex<T>,2> grid2_(grid.shape());
auto grid2=grid2_.getMav();
hartley2complex(cmav(grid), grid2, gconf.Nthreads());
grid2x_c(gconf, cmav(grid2), srv);
......@@ -1258,10 +1254,6 @@ template<typename T, typename Serv> void wstack_common(
const GridderConfig<T> &gconf, const Serv &srv, T &wmin,
double &dw, size_t &nplanes, vector<size_t> &nvis_plane, vector<int> &minplane, size_t verbosity)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
size_t nvis = srv.Nvis();
size_t nthreads = gconf.Nthreads();
......@@ -1279,8 +1271,8 @@ template<typename T, typename Serv> void wstack_common(
}
if (verbosity>0) cout << "Using " << nthreads << " threads" << endl;
if (verbosity>0) cout << "W range: " << wmin << " to " << wmax << endl;
double x0 = -0.5*nx_dirty*psx,
y0 = -0.5*ny_dirty*psy;
double x0 = -0.5*gconf.Nxdirty()*gconf.Pixsize_x(),
y0 = -0.5*gconf.Nydirty()*gconf.Pixsize_y();
double nmin = sqrt(max(1.-x0*x0-y0*y0,0.))-1.;
dw = 0.25/abs(nmin);
nplanes = max<size_t>(2,((wmax-wmin)/dw+2));
......@@ -1458,9 +1450,6 @@ template<typename T, typename Serv> void dirty2x_wstack(
if (verbosity>0) cout << "Degridding using improved w-stacking" << endl;
wstack_common(gconf, srv, wmin, dw, nplanes, nvis_plane, minplane, verbosity);
for (size_t i=0; i<nvis; ++i)
srv.setVis(i, 0.);
tmpStorage<T,2> tdirty_({nx_dirty,ny_dirty});
auto tdirty=tdirty_.getMav();
for (size_t i=0; i<nx_dirty; ++i)
......@@ -1638,8 +1627,7 @@ template<typename T> void gridding(const const_mav<T,2> &uvw,
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
x2dirty(gconf, srv, dirty, verbosity);
x2dirty(gconf, makeMsServ(baselines,idx2,ms,wgt), dirty, verbosity);
}
template<typename T> void degridding(const const_mav<T,2> &uvw,
......@@ -1652,8 +1640,7 @@ template<typename T> void degridding(const const_mav<T,2> &uvw,
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
ms.fill(0);
auto srv = MsServ<T,mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
dirty2x(gconf, dirty, srv, verbosity);
dirty2x(gconf, dirty, makeMsServ(baselines,idx2,ms,wgt), verbosity);
}
template<typename T> void full_gridding(const const_mav<T,2> &uvw,
......@@ -1665,8 +1652,7 @@ template<typename T> void full_gridding(const const_mav<T,2> &uvw,
GridderConfig<T> gconf(dirty.shape(0), dirty.shape(1), epsilon, pixsize_x, pixsize_y, nthreads);
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
auto srv = MsServ<T,const_mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
x2dirty_wstack2(gconf, srv, dirty, verbosity);
x2dirty_wstack2(gconf, makeMsServ(baselines,idx2,ms,wgt), dirty, verbosity);
}
template<typename T> void full_degridding(const const_mav<T,2> &uvw,
......@@ -1679,8 +1665,7 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
auto idx = getWgtIndices(baselines, gconf, wgt);
auto idx2 = const_mav<uint32_t,1>(idx.data(),{idx.size()});
ms.fill(0);
auto srv = MsServ<T,mav<complex<T>,2>,const_mav<T,2>>(baselines,idx2,ms,wgt);
dirty2x_wstack(gconf, dirty, srv, verbosity);
dirty2x_wstack(gconf, dirty, makeMsServ(baselines,idx2,ms,wgt), verbosity);
}
} // namespace detail
......
......@@ -728,6 +728,7 @@ template<typename T> pyarr<complex<T>> Pydirty2vis_wstack(const PyBaselines<T> &
auto dirty2=make_const_mav<2>(dirty_);
{
py::gil_scoped_release release;
vis2.fill(0);
dirty2vis_wstack<T>(baselines, gconf, idx2, dirty2, vis2);
}
return vis;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment