Commit e81b20be authored by Martin Reinecke
Browse files

streamlining

parent cee90fbb
......@@ -300,7 +300,7 @@ struct RowChan
template<typename T> void complex2hartley
(const const_mav<complex<T>, 2> &grid, const mav<T,2> &grid2, size_t nthreads)
{
myassert(grid.shape()==grid2.shape(), "shape mismatch");
checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1);
#pragma omp parallel for num_threads(nthreads)
......@@ -319,7 +319,7 @@ template<typename T> void complex2hartley
template<typename T> void hartley2complex
(const const_mav<T,2> &grid, const mav<complex<T>,2> &grid2, size_t nthreads)
{
myassert(grid.shape()==grid2.shape(), "shape mismatch");
checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1);
#pragma omp parallel for num_threads(nthreads)
......@@ -339,7 +339,7 @@ template<typename T> void hartley2complex
template<typename T> void hartley2_2D(const const_mav<T,2> &in,
const mav<T,2> &out, size_t nthreads)
{
myassert(in.shape()==out.shape(), "shape mismatch");
checkShape(in.shape(), out.shape());
size_t nu=in.shape(0), nv=in.shape(1);
ptrdiff_t sz=ptrdiff_t(sizeof(T));
pocketfft::stride_t stri{sz*in.stride(0), sz*in.stride(1)};
......@@ -366,42 +366,23 @@ template<typename T> void hartley2_2D(const const_mav<T,2> &in,
/* ES gridding kernel.
   The shape parameter beta is fixed at construction time to 2.3 times the
   requested kernel support. */
class ES_Kernel
  {
  protected:
    double beta;

  public:
    /* Builds the kernel for a support of supp. */
    ES_Kernel(size_t supp) : beta(2.3*supp) {}

    /* Evaluates the kernel at v: exp(beta*(sqrt(1-v*v)-1)). */
    double operator()(double v) const
      {
      const double root = sqrt(1.-v*v);
      return exp(beta*(root-1.));
      }

    /* Returns the first index i>=1 of the error table whose entry is smaller
       than epsilon squared. If no entry qualifies, myfail is invoked with a
       message saying that the requested epsilon is too small. */
    static size_t get_supp(double epsilon)
      {
      // Entry 0 is never examined; the scan below starts at index 1.
      static const vector<double> maxmaperr { 1e8, 0.32, 0.021, 6.2e-4,
        1.08e-5, 1.25e-7, 8.25e-10, 5.70e-12, 1.22e-13, 2.48e-15, 4.82e-17,
        6.74e-19, 5.41e-21, 4.41e-23, 7.88e-25, 3.9e-26 };
      const double epssq = epsilon*epsilon;
      size_t idx = 1;
      while (idx<maxmaperr.size())
        {
        if (maxmaperr[idx]<epssq) return idx;
        ++idx;
        }
      myfail("requested epsilon too small - minimum is 2e-13");
      }
  };
class ES_Kernel_with_correction: public ES_Kernel
{
protected:
private:
static constexpr double pi = 3.141592653589793238462643383279502884197;
double beta;
int p;
vector<double> x, wgt, psi;
size_t supp;
public:
ES_Kernel_with_correction(size_t supp_, size_t nthreads)
: ES_Kernel(supp_), p(int(1.5*supp_+2)), supp(supp_)
ES_Kernel(size_t supp_, size_t nthreads)
: beta(2.3*supp_), p(int(1.5*supp_+2)), supp(supp_)
{
legendre_prep(2*p,x,wgt,nthreads);
psi=x;
for (auto &v:psi)
v=operator()(v);
}
double operator()(double v) const { return exp(beta*(sqrt(1.-v*v)-1.)); }
/* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
double corfac(double v) const
......@@ -411,13 +392,25 @@ class ES_Kernel_with_correction: public ES_Kernel
tmp += wgt[i]*psi[i]*cos(pi*supp*v*x[i]);
return 1./(supp*tmp);
}
static size_t get_supp(double epsilon)
{
static const vector<double> maxmaperr { 1e8, 0.32, 0.021, 6.2e-4,
1.08e-5, 1.25e-7, 8.25e-10, 5.70e-12, 1.22e-13, 2.48e-15, 4.82e-17,
6.74e-19, 5.41e-21, 4.41e-23, 7.88e-25, 3.9e-26 };
double epssq = epsilon*epsilon;
for (size_t i=1; i<maxmaperr.size(); ++i)
if (epssq>maxmaperr[i]) return i;
myfail("requested epsilon too small - minimum is 2e-13");
}
};
/* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
vector<double> correction_factors(size_t n, size_t nval, size_t supp,
size_t nthreads)
{
ES_Kernel_with_correction kernel(supp, nthreads);
ES_Kernel kernel(supp, nthreads);
vector<double> res(nval);
double xn = 1./n;
#pragma omp parallel for schedule(static) num_threads(nthreads)
......@@ -451,7 +444,8 @@ class Baselines
size_t shift, mask;
public:
template<typename T> Baselines(const const_mav<T,2> &coord_, const const_mav<T,1> &freq)
template<typename T> Baselines(const const_mav<T,2> &coord_,
const const_mav<T,1> &freq)
{
constexpr double speedOfLight = 299792458.;
myassert(coord_.shape(1)==3, "dimension mismatch");
......@@ -488,8 +482,7 @@ class Baselines
mav<T,2> &res) const
{
size_t nvis = idx.shape(0);
myassert(res.shape(0)==nvis, "shape mismatch");
myassert(res.shape(1)==3, "shape mismatch");
checkShape(res.shape(), {idx.shape(0), 3});
for (size_t i=0; i<nvis; i++)
{
auto uvw = effectiveCoord(idx(i));
......@@ -502,10 +495,9 @@ class Baselines
template<typename T> void ms2vis(const mav<const T,2> &ms,
const mav<const uint32_t,1> &idx, mav<T,1> &vis, size_t nthreads) const
{
myassert(ms.shape(0)==nrows, "shape mismatch");
myassert(ms.shape(1)==nchan, "shape mismatch");
checkShape(ms.shape(), {nrows, nchan});
size_t nvis = idx.shape(0);
myassert(vis.shape(0)==nvis, "shape mismatch");
checkShape(vis.shape(), {nvis});
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
......@@ -518,9 +510,8 @@ class Baselines
const mav<const uint32_t,1> &idx, mav<T,2> &ms, size_t nthreads) const
{
size_t nvis = vis.shape(0);
myassert(idx.shape(0)==nvis, "shape mismatch");
myassert(ms.shape(0)==nrows, "shape mismatch");
myassert(ms.shape(1)==nchan, "shape mismatch");
checkShape(idx.shape(), {nvis});
checkShape(ms.shape(), {nrows, nchan});
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
......@@ -539,6 +530,8 @@ class GridderConfig
double beta;
vector<double> cfu, cfv;
size_t nthreads;
// double ushift, vshift;
int maxiu0, maxiv0;
complex<double> wscreen(double x, double y, double w, bool adjoint) const
{
......@@ -560,7 +553,9 @@ class GridderConfig
supp(ES_Kernel::get_supp(epsilon)), nsafe((supp+1)/2),
nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)),
beta(2.3*supp),
cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_)
cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_),
// ushift(-supp*0.5+1), vshift(-supp*0.5+1),
maxiu0((nu+nsafe)-supp), maxiv0((nv+nsafe)-supp)
{
myassert((nx_dirty&1)==0, "nx_dirty must be even");
myassert((ny_dirty&1)==0, "ny_dirty must be even");
......@@ -680,10 +675,10 @@ class GridderConfig
{
u=fmod1(u_in*psx)*nu,
iu0 = int(u-supp*0.5 + 1 + nu) - nu;
iu0 = min<int>(iu0, (nu+nsafe)-supp);
iu0 = min(iu0, maxiu0);
v=fmod1(v_in*psy)*nv;
iv0 = int(v-supp*0.5 + 1 + nv) - nv;
iv0 = min<int>(iv0, (nv+nsafe)-supp);
iv0 = min(iv0, maxiv0);
}
template<typename T> void apply_wscreen(const const_mav<complex<T>,2> &dirty,
......@@ -1182,7 +1177,7 @@ template<typename T> void get_correlations
template<typename T> void apply_wcorr(const GridderConfig &gconf,
const mav<T,2> &dirty, const ES_Kernel_with_correction &kernel, double dw)
const mav<T,2> &dirty, const ES_Kernel &kernel, double dw)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
......@@ -1324,7 +1319,7 @@ template<typename T, typename Serv> void x2dirty(
dirty(i,j) += tdirty(i,j).real();
}
// correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel_with_correction(supp, nthreads), dw);
apply_wcorr(gconf, dirty, ES_Kernel(supp, nthreads), dw);
}
else
{
......@@ -1376,7 +1371,7 @@ template<typename T, typename Serv> void dirty2x(
for (size_t j=0; j<ny_dirty; ++j)
tdirty(i,j) = dirty(i,j);
// correct for w gridding
apply_wcorr(gconf, tdirty, ES_Kernel_with_correction(supp, nthreads), dw);
apply_wcorr(gconf, tdirty, ES_Kernel(supp, nthreads), dw);
vector<uint32_t> subidx;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
......
......@@ -818,7 +818,6 @@ template<typename T> pyarr<complex<T>> dirty2vis2(const PyBaselines &baselines,
auto wgt2 = make_const_mav<1>(wgt);
auto vis = makeArray<complex<T>>({idx2.shape(0)});
auto vis2 = make_mav<1>(vis);
vis2.fill(0);
{
py::gil_scoped_release release;
vis2.fill(0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment