Commit e81b20be authored by Martin Reinecke's avatar Martin Reinecke
Browse files

streamlining

parent cee90fbb
...@@ -300,7 +300,7 @@ struct RowChan ...@@ -300,7 +300,7 @@ struct RowChan
template<typename T> void complex2hartley template<typename T> void complex2hartley
(const const_mav<complex<T>, 2> &grid, const mav<T,2> &grid2, size_t nthreads) (const const_mav<complex<T>, 2> &grid, const mav<T,2> &grid2, size_t nthreads)
{ {
myassert(grid.shape()==grid2.shape(), "shape mismatch"); checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1); size_t nu=grid.shape(0), nv=grid.shape(1);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
...@@ -319,7 +319,7 @@ template<typename T> void complex2hartley ...@@ -319,7 +319,7 @@ template<typename T> void complex2hartley
template<typename T> void hartley2complex template<typename T> void hartley2complex
(const const_mav<T,2> &grid, const mav<complex<T>,2> &grid2, size_t nthreads) (const const_mav<T,2> &grid, const mav<complex<T>,2> &grid2, size_t nthreads)
{ {
myassert(grid.shape()==grid2.shape(), "shape mismatch"); checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1); size_t nu=grid.shape(0), nv=grid.shape(1);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
...@@ -339,7 +339,7 @@ template<typename T> void hartley2complex ...@@ -339,7 +339,7 @@ template<typename T> void hartley2complex
template<typename T> void hartley2_2D(const const_mav<T,2> &in, template<typename T> void hartley2_2D(const const_mav<T,2> &in,
const mav<T,2> &out, size_t nthreads) const mav<T,2> &out, size_t nthreads)
{ {
myassert(in.shape()==out.shape(), "shape mismatch"); checkShape(in.shape(), out.shape());
size_t nu=in.shape(0), nv=in.shape(1); size_t nu=in.shape(0), nv=in.shape(1);
ptrdiff_t sz=ptrdiff_t(sizeof(T)); ptrdiff_t sz=ptrdiff_t(sizeof(T));
pocketfft::stride_t stri{sz*in.stride(0), sz*in.stride(1)}; pocketfft::stride_t stri{sz*in.stride(0), sz*in.stride(1)};
...@@ -366,42 +366,23 @@ template<typename T> void hartley2_2D(const const_mav<T,2> &in, ...@@ -366,42 +366,23 @@ template<typename T> void hartley2_2D(const const_mav<T,2> &in,
class ES_Kernel class ES_Kernel
{ {
protected: private:
double beta;
public:
ES_Kernel(size_t supp) : beta(2.3*supp) {}
double operator()(double v) const { return exp(beta*(sqrt(1.-v*v)-1.)); }
static size_t get_supp(double epsilon)
{
static const vector<double> maxmaperr { 1e8, 0.32, 0.021, 6.2e-4,
1.08e-5, 1.25e-7, 8.25e-10, 5.70e-12, 1.22e-13, 2.48e-15, 4.82e-17,
6.74e-19, 5.41e-21, 4.41e-23, 7.88e-25, 3.9e-26 };
double epssq = epsilon*epsilon;
for (size_t i=1; i<maxmaperr.size(); ++i)
if (epssq>maxmaperr[i]) return i;
myfail("requested epsilon too small - minimum is 2e-13");
}
};
class ES_Kernel_with_correction: public ES_Kernel
{
protected:
static constexpr double pi = 3.141592653589793238462643383279502884197; static constexpr double pi = 3.141592653589793238462643383279502884197;
double beta;
int p; int p;
vector<double> x, wgt, psi; vector<double> x, wgt, psi;
size_t supp; size_t supp;
public: public:
ES_Kernel_with_correction(size_t supp_, size_t nthreads) ES_Kernel(size_t supp_, size_t nthreads)
: ES_Kernel(supp_), p(int(1.5*supp_+2)), supp(supp_) : beta(2.3*supp_), p(int(1.5*supp_+2)), supp(supp_)
{ {
legendre_prep(2*p,x,wgt,nthreads); legendre_prep(2*p,x,wgt,nthreads);
psi=x; psi=x;
for (auto &v:psi) for (auto &v:psi)
v=operator()(v); v=operator()(v);
} }
double operator()(double v) const { return exp(beta*(sqrt(1.-v*v)-1.)); }
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
double corfac(double v) const double corfac(double v) const
...@@ -411,13 +392,25 @@ class ES_Kernel_with_correction: public ES_Kernel ...@@ -411,13 +392,25 @@ class ES_Kernel_with_correction: public ES_Kernel
tmp += wgt[i]*psi[i]*cos(pi*supp*v*x[i]); tmp += wgt[i]*psi[i]*cos(pi*supp*v*x[i]);
return 1./(supp*tmp); return 1./(supp*tmp);
} }
static size_t get_supp(double epsilon)
{
static const vector<double> maxmaperr { 1e8, 0.32, 0.021, 6.2e-4,
1.08e-5, 1.25e-7, 8.25e-10, 5.70e-12, 1.22e-13, 2.48e-15, 4.82e-17,
6.74e-19, 5.41e-21, 4.41e-23, 7.88e-25, 3.9e-26 };
double epssq = epsilon*epsilon;
for (size_t i=1; i<maxmaperr.size(); ++i)
if (epssq>maxmaperr[i]) return i;
myfail("requested epsilon too small - minimum is 2e-13");
}
}; };
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
vector<double> correction_factors(size_t n, size_t nval, size_t supp, vector<double> correction_factors(size_t n, size_t nval, size_t supp,
size_t nthreads) size_t nthreads)
{ {
ES_Kernel_with_correction kernel(supp, nthreads); ES_Kernel kernel(supp, nthreads);
vector<double> res(nval); vector<double> res(nval);
double xn = 1./n; double xn = 1./n;
#pragma omp parallel for schedule(static) num_threads(nthreads) #pragma omp parallel for schedule(static) num_threads(nthreads)
...@@ -451,7 +444,8 @@ class Baselines ...@@ -451,7 +444,8 @@ class Baselines
size_t shift, mask; size_t shift, mask;
public: public:
template<typename T> Baselines(const const_mav<T,2> &coord_, const const_mav<T,1> &freq) template<typename T> Baselines(const const_mav<T,2> &coord_,
const const_mav<T,1> &freq)
{ {
constexpr double speedOfLight = 299792458.; constexpr double speedOfLight = 299792458.;
myassert(coord_.shape(1)==3, "dimension mismatch"); myassert(coord_.shape(1)==3, "dimension mismatch");
...@@ -488,8 +482,7 @@ class Baselines ...@@ -488,8 +482,7 @@ class Baselines
mav<T,2> &res) const mav<T,2> &res) const
{ {
size_t nvis = idx.shape(0); size_t nvis = idx.shape(0);
myassert(res.shape(0)==nvis, "shape mismatch"); checkShape(res.shape(), {idx.shape(0), 3});
myassert(res.shape(1)==3, "shape mismatch");
for (size_t i=0; i<nvis; i++) for (size_t i=0; i<nvis; i++)
{ {
auto uvw = effectiveCoord(idx(i)); auto uvw = effectiveCoord(idx(i));
...@@ -502,10 +495,9 @@ class Baselines ...@@ -502,10 +495,9 @@ class Baselines
template<typename T> void ms2vis(const mav<const T,2> &ms, template<typename T> void ms2vis(const mav<const T,2> &ms,
const mav<const uint32_t,1> &idx, mav<T,1> &vis, size_t nthreads) const const mav<const uint32_t,1> &idx, mav<T,1> &vis, size_t nthreads) const
{ {
myassert(ms.shape(0)==nrows, "shape mismatch"); checkShape(ms.shape(), {nrows, nchan});
myassert(ms.shape(1)==nchan, "shape mismatch");
size_t nvis = idx.shape(0); size_t nvis = idx.shape(0);
myassert(vis.shape(0)==nvis, "shape mismatch"); checkShape(vis.shape(), {nvis});
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i) for (size_t i=0; i<nvis; ++i)
{ {
...@@ -518,9 +510,8 @@ class Baselines ...@@ -518,9 +510,8 @@ class Baselines
const mav<const uint32_t,1> &idx, mav<T,2> &ms, size_t nthreads) const const mav<const uint32_t,1> &idx, mav<T,2> &ms, size_t nthreads) const
{ {
size_t nvis = vis.shape(0); size_t nvis = vis.shape(0);
myassert(idx.shape(0)==nvis, "shape mismatch"); checkShape(idx.shape(), {nvis});
myassert(ms.shape(0)==nrows, "shape mismatch"); checkShape(ms.shape(), {nrows, nchan});
myassert(ms.shape(1)==nchan, "shape mismatch");
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i) for (size_t i=0; i<nvis; ++i)
{ {
...@@ -539,6 +530,8 @@ class GridderConfig ...@@ -539,6 +530,8 @@ class GridderConfig
double beta; double beta;
vector<double> cfu, cfv; vector<double> cfu, cfv;
size_t nthreads; size_t nthreads;
// double ushift, vshift;
int maxiu0, maxiv0;
complex<double> wscreen(double x, double y, double w, bool adjoint) const complex<double> wscreen(double x, double y, double w, bool adjoint) const
{ {
...@@ -560,7 +553,9 @@ class GridderConfig ...@@ -560,7 +553,9 @@ class GridderConfig
supp(ES_Kernel::get_supp(epsilon)), nsafe((supp+1)/2), supp(ES_Kernel::get_supp(epsilon)), nsafe((supp+1)/2),
nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)), nu(max(2*nsafe,2*nx_dirty)), nv(max(2*nsafe,2*ny_dirty)),
beta(2.3*supp), beta(2.3*supp),
cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_) cfu(nx_dirty), cfv(ny_dirty), nthreads(nthreads_),
// ushift(-supp*0.5+1), vshift(-supp*0.5+1),
maxiu0((nu+nsafe)-supp), maxiv0((nv+nsafe)-supp)
{ {
myassert((nx_dirty&1)==0, "nx_dirty must be even"); myassert((nx_dirty&1)==0, "nx_dirty must be even");
myassert((ny_dirty&1)==0, "ny_dirty must be even"); myassert((ny_dirty&1)==0, "ny_dirty must be even");
...@@ -680,10 +675,10 @@ class GridderConfig ...@@ -680,10 +675,10 @@ class GridderConfig
{ {
u=fmod1(u_in*psx)*nu, u=fmod1(u_in*psx)*nu,
iu0 = int(u-supp*0.5 + 1 + nu) - nu; iu0 = int(u-supp*0.5 + 1 + nu) - nu;
iu0 = min<int>(iu0, (nu+nsafe)-supp); iu0 = min(iu0, maxiu0);
v=fmod1(v_in*psy)*nv; v=fmod1(v_in*psy)*nv;
iv0 = int(v-supp*0.5 + 1 + nv) - nv; iv0 = int(v-supp*0.5 + 1 + nv) - nv;
iv0 = min<int>(iv0, (nv+nsafe)-supp); iv0 = min(iv0, maxiv0);
} }
template<typename T> void apply_wscreen(const const_mav<complex<T>,2> &dirty, template<typename T> void apply_wscreen(const const_mav<complex<T>,2> &dirty,
...@@ -1182,7 +1177,7 @@ template<typename T> void get_correlations ...@@ -1182,7 +1177,7 @@ template<typename T> void get_correlations
template<typename T> void apply_wcorr(const GridderConfig &gconf, template<typename T> void apply_wcorr(const GridderConfig &gconf,
const mav<T,2> &dirty, const ES_Kernel_with_correction &kernel, double dw) const mav<T,2> &dirty, const ES_Kernel &kernel, double dw)
{ {
auto nx_dirty=gconf.Nxdirty(); auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty(); auto ny_dirty=gconf.Nydirty();
...@@ -1324,7 +1319,7 @@ template<typename T, typename Serv> void x2dirty( ...@@ -1324,7 +1319,7 @@ template<typename T, typename Serv> void x2dirty(
dirty(i,j) += tdirty(i,j).real(); dirty(i,j) += tdirty(i,j).real();
} }
// correct for w gridding // correct for w gridding
apply_wcorr(gconf, dirty, ES_Kernel_with_correction(supp, nthreads), dw); apply_wcorr(gconf, dirty, ES_Kernel(supp, nthreads), dw);
} }
else else
{ {
...@@ -1376,7 +1371,7 @@ template<typename T, typename Serv> void dirty2x( ...@@ -1376,7 +1371,7 @@ template<typename T, typename Serv> void dirty2x(
for (size_t j=0; j<ny_dirty; ++j) for (size_t j=0; j<ny_dirty; ++j)
tdirty(i,j) = dirty(i,j); tdirty(i,j) = dirty(i,j);
// correct for w gridding // correct for w gridding
apply_wcorr(gconf, tdirty, ES_Kernel_with_correction(supp, nthreads), dw); apply_wcorr(gconf, tdirty, ES_Kernel(supp, nthreads), dw);
vector<uint32_t> subidx; vector<uint32_t> subidx;
tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()}); tmpStorage<complex<T>,2> grid_({gconf.Nu(),gconf.Nv()});
......
...@@ -818,7 +818,6 @@ template<typename T> pyarr<complex<T>> dirty2vis2(const PyBaselines &baselines, ...@@ -818,7 +818,6 @@ template<typename T> pyarr<complex<T>> dirty2vis2(const PyBaselines &baselines,
auto wgt2 = make_const_mav<1>(wgt); auto wgt2 = make_const_mav<1>(wgt);
auto vis = makeArray<complex<T>>({idx2.shape(0)}); auto vis = makeArray<complex<T>>({idx2.shape(0)});
auto vis2 = make_mav<1>(vis); auto vis2 = make_mav<1>(vis);
vis2.fill(0);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
vis2.fill(0); vis2.fill(0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment