Commit aaf6dfcd authored by Martin Reinecke
Browse files

Merge branch 'tweak_es_kernel' into 'ducc0'

Tweak ES kernel

See merge request !19
parents 1ef56637 6d45cc1d
Pipeline #77290 passed with stages
in 13 minutes and 6 seconds
......@@ -34,6 +34,7 @@
#include "ducc0/infra/threading.h"
#include "ducc0/infra/useful_macros.h"
#include "ducc0/infra/mav.h"
#include "ducc0/infra/simd.h"
#include "ducc0/math/es_kernel.h"
namespace ducc0 {
......@@ -536,8 +537,11 @@ template<typename T, typename T2=complex<T>> class Helper
public:
const T2 *p0r;
T2 *p0w;
T kernel[64] DUCC0_ALIGNED(64);
static constexpr size_t vlen=native_simd<T>::size();
union {
T scalar[64];
native_simd<T> simd[64/vlen];
} kernel;
Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_,
vector<std::mutex> &locks_, double w0_=-1, double dw_=-1)
......@@ -551,13 +555,16 @@ template<typename T, typename T2=complex<T>> class Helper
w0(w0_),
xdw(T(1)/dw_),
nexp(2*supp + do_w_gridding),
nvecs(vlen*((nexp+vlen-1)/vlen)),
nvecs((nexp+vlen-1)/vlen),
locks(locks_)
{}
{
MR_assert(2*supp+1<=64, "support too large");
for (auto &v: kernel.simd) v=0;
}
~Helper() { if (grid_w) dump(); }
int lineJump() const { return sv; }
T Wfac() const { return kernel[2*supp]; }
T Wfac() const { return kernel.scalar[2*supp]; }
void prep(const UVW &in)
{
double u, v;
......@@ -567,20 +574,13 @@ template<typename T, typename T2=complex<T>> class Helper
double y0 = xsupp*(iv0-v);
for (int i=0; i<supp; ++i)
{
kernel[i ] = T(x0+i*xsupp);
kernel[i+supp] = T(y0+i*xsupp);
kernel.scalar[i ] = T(x0+i*xsupp);
kernel.scalar[i+supp] = T(y0+i*xsupp);
}
if (do_w_gridding)
kernel[2*supp] = min(T(1), T(xdw*xsupp*abs(w0-in.w)));
for (size_t i=nexp; i<nvecs; ++i)
kernel[i]=0;
for (size_t i=0; i<nvecs; ++i)
{
kernel[i] = T(1) - kernel[i]*kernel[i];
kernel[i] = (kernel[i]<0) ? T(-200.) : beta*(sqrt(kernel[i])-T(1));
}
kernel.scalar[2*supp] = T(xdw*xsupp*abs(w0-in.w));
for (size_t i=0; i<nvecs; ++i)
kernel[i] = exp(kernel[i]);
kernel.simd[i] = esk(kernel.simd[i], beta);
if ((iu0<bu0) || (iv0<bv0) || (iu0+supp>bu0+su) || (iv0+supp>bv0+sv))
{
if (grid_w) { dump(); fill(wbuf.begin(), wbuf.end(), T(0)); }
......@@ -681,8 +681,8 @@ template<typename T, typename Serv> void x2grid_c
{
Helper<T> hlp(gconf, nullptr, grid.vdata(), locks, w0, dw);
int jump = hlp.lineJump();
const T * DUCC0_RESTRICT ku = hlp.kernel;
const T * DUCC0_RESTRICT kv = hlp.kernel+supp;
const T * DUCC0_RESTRICT ku = hlp.kernel.scalar;
const T * DUCC0_RESTRICT kv = hlp.kernel.scalar+supp;
while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart)
{
......@@ -729,8 +729,8 @@ template<typename T, typename Serv> void grid2x_c
{
Helper<T> hlp(gconf, grid.data(), nullptr, locks, w0, dw);
int jump = hlp.lineJump();
const T * DUCC0_RESTRICT ku = hlp.kernel;
const T * DUCC0_RESTRICT kv = hlp.kernel+supp;
const T * DUCC0_RESTRICT ku = hlp.kernel.scalar;
const T * DUCC0_RESTRICT kv = hlp.kernel.scalar+supp;
while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart)
{
......
......@@ -220,9 +220,13 @@ template<typename T> class Interpolator
mapper[itheta*ncp+iphi].push_back(i);
}
size_t cnt=0;
for (const auto &vec: mapper)
for (auto &vec: mapper)
{
for (auto i:vec)
idx[cnt++] = i;
vec.clear();
vec.shrink_to_fit();
}
return idx;
}
......@@ -434,7 +438,14 @@ template<typename T> class Interpolator
auto idx = getIdx(ptg);
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{
vector<T> wt(supp), wp(supp);
union {
native_simd<T> simd[64/vl];
T scalar[64];
} kdata;
T *wt(kdata.scalar), *wp(kdata.scalar+supp);
size_t nvec = (2*supp+vl-1)/vl;
for (auto &v: kdata.simd) v=0;
MR_assert(2*supp<=64, "support too large");
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
......@@ -446,11 +457,13 @@ template<typename T> class Interpolator
T f0=T(0.5*supp+ptg(i,0)*xdtheta);
size_t i0 = size_t(f0+T(1));
for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1);
wt[t] = (t+i0-f0)*delta - 1;
T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1);
wp[t] = (t+i1-f1)*delta - 1;
for (size_t t=0; t<nvec; ++t)
kdata.simd[t] = kernel(kdata.simd[t]);
psiarr[0]=1.;
double psi=ptg(i,2);
double cpsi=cos(psi), spsi=sin(psi);
......@@ -484,25 +497,25 @@ template<typename T> class Interpolator
{
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<1,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<2,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<3,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<4,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<5,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<6,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<7,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
......@@ -549,25 +562,25 @@ template<typename T> class Interpolator
{
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<1,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<2,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<3,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<4,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<5,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<6,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
interpol_help0<7,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
......@@ -629,7 +642,14 @@ template<typename T> class Interpolator
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{
size_t b_theta=99999999999999, b_phi=9999999999999999;
vector<T> wt(supp), wp(supp);
union {
native_simd<T> simd[64/vl];
T scalar[64];
} kdata;
T *wt(kdata.scalar), *wp(kdata.scalar+supp);
size_t nvec = (2*supp+vl-1)/vl;
for (auto &v: kdata.simd) v=0;
MR_assert(2*supp<=64, "support too large");
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
......@@ -641,11 +661,13 @@ template<typename T> class Interpolator
T f0=T(0.5)*supp+ptg(i,0)*xdtheta;
size_t i0 = size_t(f0+1.);
for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1);
wt[t] = (t+i0-f0)*delta - 1;
T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1);
wp[t] = (t+i1-f1)*delta - 1;
for (size_t t=0; t<nvec; ++t)
kdata.simd[t] = kernel(kdata.simd[t]);
psiarr[0]=1.;
double psi=ptg(i,2);
double cpsi=cos(psi), spsi=sin(psi);
......@@ -696,25 +718,25 @@ template<typename T> class Interpolator
{
#ifdef SPECIAL_CASING
case 1:
deinterpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<1,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 2:
deinterpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<2,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 3:
deinterpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<3,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 4:
deinterpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<4,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 5:
deinterpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<5,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 6:
deinterpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<6,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 7:
deinterpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<7,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
#endif
default:
......@@ -759,25 +781,25 @@ template<typename T> class Interpolator
{
#ifdef SPECIAL_CASING
case 1:
deinterpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<1,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 2:
deinterpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<2,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 3:
deinterpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<3,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 4:
deinterpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<4,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 5:
deinterpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<5,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 6:
deinterpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<6,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
case 7:
deinterpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
deinterpol_help0<7,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break;
#endif
default:
......
......@@ -25,6 +25,7 @@
#include <cmath>
#include <vector>
#include <array>
#include "ducc0/infra/simd.h"
#include "ducc0/infra/error_handling.h"
#include "ducc0/infra/threading.h"
#include "ducc0/math/constants.h"
......@@ -36,6 +37,28 @@ namespace detail_es_kernel {
using namespace std;
template<typename T> T esk (T v, T beta)
{
auto tmp = (1-v)*(1+v);
auto tmp2 = tmp>=0;
return tmp2*exp(T(beta)*(sqrt(tmp*tmp2)-1));
}
template<typename T> class exponator
{
public:
T operator()(T val) { return exp(val); }
};
template<typename T> native_simd<T> esk (native_simd<T> v, T beta)
{
auto tmp = (T(1)-v)*(T(1)+v);
auto mask = tmp<T(0);
where(mask,tmp)*=T(0);
auto res = (beta*(sqrt(tmp)-T(1))).apply(exponator<T>());
where (mask,res)*=T(0);
return res;
}
/* This class implements the "exponential of semicircle" gridding kernel
described in https://arxiv.org/abs/1808.06736 */
class ES_Kernel
......@@ -62,11 +85,11 @@ class ES_Kernel
: ES_Kernel(supp_, 2., nthreads){}
template<typename T> T operator()(T v) const
{
auto tmp = (1-v)*(1+v);
auto tmp2 = tmp>=0;
return tmp2*exp(T(beta)*(sqrt(tmp*tmp2)-1));
}
{ return esk(v, T(beta)); }
template<typename T> native_simd<T> operator()(native_simd<T> v) const
{ return esk(v, T(beta)); }
/* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
double corfac(double v) const
......@@ -138,6 +161,7 @@ class ES_Kernel
}
using detail_es_kernel::esk;
using detail_es_kernel::ES_Kernel;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment