Commit d55fad0c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to get more accuracy

parent 7694232c
...@@ -63,6 +63,8 @@ template<> struct VLEN<float> { static constexpr size_t val=4; }; ...@@ -63,6 +63,8 @@ template<> struct VLEN<float> { static constexpr size_t val=4; };
template<> struct VLEN<double> { static constexpr size_t val=2; }; template<> struct VLEN<double> { static constexpr size_t val=2; };
#endif #endif
// "mav" stands for "multidimensional array view"
template<typename T, size_t ndim> class mav template<typename T, size_t ndim> class mav
{ {
static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays"); static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays");
...@@ -83,8 +85,6 @@ template<typename T, size_t ndim> class mav ...@@ -83,8 +85,6 @@ template<typename T, size_t ndim> class mav
for (size_t d=2; d<=ndim; ++d) for (size_t d=2; d<=ndim; ++d)
str[ndim-d] = str[ndim-d+1]*shp[ndim-d+1]; str[ndim-d] = str[ndim-d+1]*shp[ndim-d+1];
} }
operator mav<const T, ndim>() const
{ return mav<const T, ndim>(d,shp,str); }
T &operator[](size_t i) const T &operator[](size_t i) const
{ return operator()(i); } { return operator()(i); }
T &operator()(size_t i) const T &operator()(size_t i) const
...@@ -120,6 +120,7 @@ template<typename T, size_t ndim> class mav ...@@ -120,6 +120,7 @@ template<typename T, size_t ndim> class mav
} }
void fill(const T &val) const void fill(const T &val) const
{ {
// FIXME: special cases for contiguous arrays and/or zeroing?
if (ndim==1) if (ndim==1)
for (size_t i=0; i<shp[0]; ++i) for (size_t i=0; i<shp[0]; ++i)
d[str[0]*i]=val; d[str[0]*i]=val;
...@@ -160,15 +161,61 @@ template<typename T, size_t ndim> class tmpStorage ...@@ -160,15 +161,61 @@ template<typename T, size_t ndim> class tmpStorage
void fill(const T & val) void fill(const T & val)
{ std::fill(d.begin(), d.end(), val); } { std::fill(d.begin(), d.end(), val); }
}; };
// //
// basic utilities // basic utilities
// //
#if defined (__GNUC__)
#define LOC_ CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__)
#else
#define LOC_ CodeLocation(__FILE__, __LINE__)
#endif
void myassert(bool cond, const char *msg) #define myfail(...) \
do { \
std::ostringstream os; \
streamDump__(os, LOC_, "\n", ##__VA_ARGS__, "\n"); \
throw std::runtime_error(os.str()); \
} while(0)
#define myassert(cond,...) \
do { \
if (cond); \
else { myfail("Assertion failure\n", ##__VA_ARGS__); } \
} while(0)
template<typename T>
inline void streamDump__(std::ostream &os, const T& value)
{ os << value; }
template<typename T, typename ... Args>
inline void streamDump__(std::ostream &os, const T& value,
const Args& ... args)
{ {
if (cond) return; os << value;
throw runtime_error(msg); streamDump__(os, args...);
} }
// to be replaced with std::source_location once available
class CodeLocation
{
private:
const char *file, *func;
int line;
public:
CodeLocation(const char *file_, int line_, const char *func_=nullptr)
: file(file_), func(func_), line(line_) {}
ostream &print(ostream &os) const
{
os << "file: " << file << ", line: " << line;
if (func) os << ", function: " << func;
return os;
}
};
inline std::ostream &operator<<(std::ostream &os, const CodeLocation &loc)
{ return loc.print(os); }
template<size_t ndim> void checkShape template<size_t ndim> void checkShape
(const array<size_t, ndim> &shp1, const array<size_t, ndim> &shp2) (const array<size_t, ndim> &shp1, const array<size_t, ndim> &shp2)
...@@ -265,7 +312,7 @@ size_t get_supp(double epsilon) ...@@ -265,7 +312,7 @@ size_t get_supp(double epsilon)
for (size_t i=1; i<maxmaperr.size(); ++i) for (size_t i=1; i<maxmaperr.size(); ++i)
if (epssq>maxmaperr[i]) return i; if (epssq>maxmaperr[i]) return i;
throw runtime_error("requested epsilon too small - minimum is 2e-13"); myfail("requested epsilon too small - minimum is 2e-13");
} }
struct RowChan struct RowChan
...@@ -339,28 +386,27 @@ template<typename T> void hartley2_2D(const const_mav<T,2> &in, const mav<T,2> & ...@@ -339,28 +386,27 @@ template<typename T> void hartley2_2D(const const_mav<T,2> &in, const mav<T,2> &
} }
template<typename T> class EC_Kernel class EC_Kernel
{ {
protected: protected:
T beta; double beta;
public: public:
EC_Kernel(size_t supp) : beta(2.3*supp) {} EC_Kernel(size_t supp) : beta(2.3*supp) {}
T operator()(T v) const { return exp(beta*(sqrt(T(1)-v*v)-T(1))); } double operator()(double v) const { return exp(beta*(sqrt(1.-v*v)-1.)); }
}; };
template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> class EC_Kernel_with_correction: public EC_Kernel
{ {
protected: protected:
static constexpr T pi = T(3.141592653589793238462643383279502884197L); static constexpr double pi = 3.141592653589793238462643383279502884197;
int p; int p;
vector<double> x, wgt, psi; vector<double> x, wgt, psi;
size_t supp; size_t supp;
public: public:
using EC_Kernel<T>::operator();
EC_Kernel_with_correction(size_t supp_, size_t nthreads) EC_Kernel_with_correction(size_t supp_, size_t nthreads)
: EC_Kernel<T>(supp_), p(int(1.5*supp_+2)), supp(supp_) : EC_Kernel(supp_), p(int(1.5*supp_+2)), supp(supp_)
{ {
legendre_prep(2*p,x,wgt,nthreads); legendre_prep(2*p,x,wgt,nthreads);
psi=x; psi=x;
...@@ -369,12 +415,12 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> ...@@ -369,12 +415,12 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
} }
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
T corfac(T v) const double corfac(double v) const
{ {
T tmp=0; double tmp=0;
for (int i=0; i<p; ++i) for (int i=0; i<p; ++i)
tmp += wgt[i]*psi[i]*cos(pi*supp*v*x[i]); tmp += wgt[i]*psi[i]*cos(pi*supp*v*x[i]);
return T(1./(supp*tmp)); return 1./(supp*tmp);
} }
}; };
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
...@@ -382,7 +428,7 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T> ...@@ -382,7 +428,7 @@ template<typename T> class EC_Kernel_with_correction: public EC_Kernel<T>
vector<double> correction_factors (size_t n, size_t nval, size_t supp, vector<double> correction_factors (size_t n, size_t nval, size_t supp,
size_t nthreads) size_t nthreads)
{ {
EC_Kernel_with_correction<double> kernel(supp, nthreads); EC_Kernel_with_correction kernel(supp, nthreads);
vector<double> res(nval); vector<double> res(nval);
double xn = 1./n; double xn = 1./n;
#pragma omp parallel for schedule(static) num_threads(nthreads) #pragma omp parallel for schedule(static) num_threads(nthreads)
...@@ -622,7 +668,7 @@ template<typename T> class GridderConfig ...@@ -622,7 +668,7 @@ template<typename T> class GridderConfig
if (j2>=nv) j2-=nv; if (j2>=nv) j2-=nv;
grid(i2,j2) = dirty(i,j)*cfu[i]*cfv[j]; grid(i2,j2) = dirty(i,j)*cfu[i]*cfv[j];
} }
hartley2_2D<T>(grid, grid, nthreads); hartley2_2D<T>(cmav(grid), grid, nthreads);
} }
void dirty2grid_c(const const_mav<complex<T>,2> &dirty, void dirty2grid_c(const const_mav<complex<T>,2> &dirty,
...@@ -1123,7 +1169,7 @@ template<typename T> void get_correlations ...@@ -1123,7 +1169,7 @@ template<typename T> void get_correlations
template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf, template<typename T, typename T2> void apply_wcorr(const GridderConfig<T> &gconf,
const mav<T2,2> &dirty, const EC_Kernel_with_correction<T> &kernel, double dw) const mav<T2,2> &dirty, const EC_Kernel_with_correction &kernel, double dw)
{ {
auto nx_dirty=gconf.Nxdirty(); auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty(); auto ny_dirty=gconf.Nydirty();
...@@ -1180,7 +1226,7 @@ template<typename T, typename Serv> void x2dirty( ...@@ -1180,7 +1226,7 @@ template<typename T, typename Serv> void x2dirty(
tmpStorage<T,2> rgrid_({nu, nv}); tmpStorage<T,2> rgrid_({nu, nv});
auto rgrid=rgrid_.getMav(); auto rgrid=rgrid_.getMav();
complex2hartley(cmav(grid), rgrid, gconf.Nthreads()); complex2hartley(cmav(grid), rgrid, gconf.Nthreads());
gconf.grid2dirty(rgrid, dirty); gconf.grid2dirty(cmav(rgrid), dirty);
} }
template<typename T, typename Serv> void dirty2x( template<typename T, typename Serv> void dirty2x(
...@@ -1275,7 +1321,7 @@ template<typename T, typename Serv> void x2dirty_wstack( ...@@ -1275,7 +1321,7 @@ template<typename T, typename Serv> void x2dirty_wstack(
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp(); auto w_supp = gconf.Supp();
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads); EC_Kernel_with_correction kernel(w_supp, nthreads);
T wmin; T wmin;
double dw; double dw;
size_t nplanes; size_t nplanes;
...@@ -1318,8 +1364,8 @@ template<typename T, typename Serv> void x2dirty_wstack( ...@@ -1318,8 +1364,8 @@ template<typename T, typename Serv> void x2dirty_wstack(
} }
grid.fill(0); grid.fill(0);
vis2grid_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(vis_loc), grid, nullmav<T,1>(), wcur, T(dw)); vis2grid_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(vis_loc), grid, nullmav<T,1>(), wcur, T(dw));
gconf.grid2dirty_c(grid, tdirty); gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(tdirty, tdirty, wcur, false); gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i) for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j) for (size_t j=0; j<ny_dirty; ++j)
...@@ -1344,7 +1390,7 @@ template<typename T, typename Serv> void x2dirty_wstack2( ...@@ -1344,7 +1390,7 @@ template<typename T, typename Serv> void x2dirty_wstack2(
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp(); auto w_supp = gconf.Supp();
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads); EC_Kernel_with_correction kernel(w_supp, nthreads);
T wmin; T wmin;
double dw; double dw;
size_t nplanes; size_t nplanes;
...@@ -1375,8 +1421,8 @@ template<typename T, typename Serv> void x2dirty_wstack2( ...@@ -1375,8 +1421,8 @@ template<typename T, typename Serv> void x2dirty_wstack2(
grid.fill(0); grid.fill(0);
Serv newsrv(srv, const_mav<uint32_t, 1>(newidx.data(), {newidx.size()})); Serv newsrv(srv, const_mav<uint32_t, 1>(newidx.data(), {newidx.size()}));
x2grid_c(gconf, newsrv, grid, wcur, T(dw)); x2grid_c(gconf, newsrv, grid, wcur, T(dw));
gconf.grid2dirty_c(grid, tdirty); gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(tdirty, tdirty, wcur, false); gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i) for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j) for (size_t j=0; j<ny_dirty; ++j)
...@@ -1403,7 +1449,7 @@ template<typename T, typename Serv> void dirty2x_wstack( ...@@ -1403,7 +1449,7 @@ template<typename T, typename Serv> void dirty2x_wstack(
size_t nvis = srv.Nvis(); size_t nvis = srv.Nvis();
auto w_supp = gconf.Supp(); auto w_supp = gconf.Supp();
size_t nthreads = gconf.Nthreads(); size_t nthreads = gconf.Nthreads();
EC_Kernel_with_correction<T> kernel(w_supp, nthreads); EC_Kernel_with_correction kernel(w_supp, nthreads);
T wmin; T wmin;
double dw; double dw;
size_t nplanes; size_t nplanes;
...@@ -1454,8 +1500,8 @@ template<typename T, typename Serv> void dirty2x_wstack( ...@@ -1454,8 +1500,8 @@ template<typename T, typename Serv> void dirty2x_wstack(
for (size_t i=0; i<nx_dirty; ++i) for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j) for (size_t j=0; j<ny_dirty; ++j)
tdirty2(i,j) = tdirty(i,j); tdirty2(i,j) = tdirty(i,j);
gconf.apply_wscreen(tdirty2, tdirty2, wcur, true); gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, true);
gconf.dirty2grid_c(tdirty2, grid); gconf.dirty2grid_c(cmav(tdirty2), grid);
grid2vis_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(grid), vis_loc, const_mav<T,1>(nullptr,{0}), wcur, T(dw)); grid2vis_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(grid), vis_loc, const_mav<T,1>(nullptr,{0}), wcur, T(dw));
#pragma omp parallel for num_threads(nthreads) #pragma omp parallel for num_threads(nthreads)
...@@ -1642,7 +1688,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw, ...@@ -1642,7 +1688,6 @@ template<typename T> void full_degridding(const const_mav<T,2> &uvw,
// public names // public names
using detail::mav; using detail::mav;
using detail::const_mav; using detail::const_mav;
using detail::myassert;
using detail::Baselines; using detail::Baselines;
using detail::GridderConfig; using detail::GridderConfig;
using detail::ms2grid_c; using detail::ms2grid_c;
...@@ -1653,6 +1698,8 @@ using detail::vis2dirty_wstack; ...@@ -1653,6 +1698,8 @@ using detail::vis2dirty_wstack;
using detail::dirty2vis_wstack; using detail::dirty2vis_wstack;
using detail::full_gridding; using detail::full_gridding;
using detail::full_degridding; using detail::full_degridding;
using detail::streamDump__;
using detail::CodeLocation;
} // namespace gridder } // namespace gridder
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment