Commit aaf6dfcd authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'tweak_es_kernel' into 'ducc0'

Tweak ES kernel

See merge request !19
parents 1ef56637 6d45cc1d
Pipeline #77290 passed with stages
in 13 minutes and 6 seconds
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "ducc0/infra/threading.h" #include "ducc0/infra/threading.h"
#include "ducc0/infra/useful_macros.h" #include "ducc0/infra/useful_macros.h"
#include "ducc0/infra/mav.h" #include "ducc0/infra/mav.h"
#include "ducc0/infra/simd.h"
#include "ducc0/math/es_kernel.h" #include "ducc0/math/es_kernel.h"
namespace ducc0 { namespace ducc0 {
...@@ -536,8 +537,11 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -536,8 +537,11 @@ template<typename T, typename T2=complex<T>> class Helper
public: public:
const T2 *p0r; const T2 *p0r;
T2 *p0w; T2 *p0w;
T kernel[64] DUCC0_ALIGNED(64);
static constexpr size_t vlen=native_simd<T>::size(); static constexpr size_t vlen=native_simd<T>::size();
union {
T scalar[64];
native_simd<T> simd[64/vlen];
} kernel;
Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_, Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_,
vector<std::mutex> &locks_, double w0_=-1, double dw_=-1) vector<std::mutex> &locks_, double w0_=-1, double dw_=-1)
...@@ -551,13 +555,16 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -551,13 +555,16 @@ template<typename T, typename T2=complex<T>> class Helper
w0(w0_), w0(w0_),
xdw(T(1)/dw_), xdw(T(1)/dw_),
nexp(2*supp + do_w_gridding), nexp(2*supp + do_w_gridding),
nvecs(vlen*((nexp+vlen-1)/vlen)), nvecs((nexp+vlen-1)/vlen),
locks(locks_) locks(locks_)
{} {
MR_assert(2*supp+1<=64, "support too large");
for (auto &v: kernel.simd) v=0;
}
~Helper() { if (grid_w) dump(); } ~Helper() { if (grid_w) dump(); }
int lineJump() const { return sv; } int lineJump() const { return sv; }
T Wfac() const { return kernel[2*supp]; } T Wfac() const { return kernel.scalar[2*supp]; }
void prep(const UVW &in) void prep(const UVW &in)
{ {
double u, v; double u, v;
...@@ -567,20 +574,13 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -567,20 +574,13 @@ template<typename T, typename T2=complex<T>> class Helper
double y0 = xsupp*(iv0-v); double y0 = xsupp*(iv0-v);
for (int i=0; i<supp; ++i) for (int i=0; i<supp; ++i)
{ {
kernel[i ] = T(x0+i*xsupp); kernel.scalar[i ] = T(x0+i*xsupp);
kernel[i+supp] = T(y0+i*xsupp); kernel.scalar[i+supp] = T(y0+i*xsupp);
} }
if (do_w_gridding) if (do_w_gridding)
kernel[2*supp] = min(T(1), T(xdw*xsupp*abs(w0-in.w))); kernel.scalar[2*supp] = T(xdw*xsupp*abs(w0-in.w));
for (size_t i=nexp; i<nvecs; ++i)
kernel[i]=0;
for (size_t i=0; i<nvecs; ++i)
{
kernel[i] = T(1) - kernel[i]*kernel[i];
kernel[i] = (kernel[i]<0) ? T(-200.) : beta*(sqrt(kernel[i])-T(1));
}
for (size_t i=0; i<nvecs; ++i) for (size_t i=0; i<nvecs; ++i)
kernel[i] = exp(kernel[i]); kernel.simd[i] = esk(kernel.simd[i], beta);
if ((iu0<bu0) || (iv0<bv0) || (iu0+supp>bu0+su) || (iv0+supp>bv0+sv)) if ((iu0<bu0) || (iv0<bv0) || (iu0+supp>bu0+su) || (iv0+supp>bv0+sv))
{ {
if (grid_w) { dump(); fill(wbuf.begin(), wbuf.end(), T(0)); } if (grid_w) { dump(); fill(wbuf.begin(), wbuf.end(), T(0)); }
...@@ -681,8 +681,8 @@ template<typename T, typename Serv> void x2grid_c ...@@ -681,8 +681,8 @@ template<typename T, typename Serv> void x2grid_c
{ {
Helper<T> hlp(gconf, nullptr, grid.vdata(), locks, w0, dw); Helper<T> hlp(gconf, nullptr, grid.vdata(), locks, w0, dw);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * DUCC0_RESTRICT ku = hlp.kernel; const T * DUCC0_RESTRICT ku = hlp.kernel.scalar;
const T * DUCC0_RESTRICT kv = hlp.kernel+supp; const T * DUCC0_RESTRICT kv = hlp.kernel.scalar+supp;
while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart) while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart)
{ {
...@@ -729,8 +729,8 @@ template<typename T, typename Serv> void grid2x_c ...@@ -729,8 +729,8 @@ template<typename T, typename Serv> void grid2x_c
{ {
Helper<T> hlp(gconf, grid.data(), nullptr, locks, w0, dw); Helper<T> hlp(gconf, grid.data(), nullptr, locks, w0, dw);
int jump = hlp.lineJump(); int jump = hlp.lineJump();
const T * DUCC0_RESTRICT ku = hlp.kernel; const T * DUCC0_RESTRICT ku = hlp.kernel.scalar;
const T * DUCC0_RESTRICT kv = hlp.kernel+supp; const T * DUCC0_RESTRICT kv = hlp.kernel.scalar+supp;
while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart) while (auto rng=sched.getNext()) for(auto ipart=rng.lo; ipart<rng.hi; ++ipart)
{ {
......
...@@ -220,9 +220,13 @@ template<typename T> class Interpolator ...@@ -220,9 +220,13 @@ template<typename T> class Interpolator
mapper[itheta*ncp+iphi].push_back(i); mapper[itheta*ncp+iphi].push_back(i);
} }
size_t cnt=0; size_t cnt=0;
for (const auto &vec: mapper) for (auto &vec: mapper)
{
for (auto i:vec) for (auto i:vec)
idx[cnt++] = i; idx[cnt++] = i;
vec.clear();
vec.shrink_to_fit();
}
return idx; return idx;
} }
...@@ -434,7 +438,14 @@ template<typename T> class Interpolator ...@@ -434,7 +438,14 @@ template<typename T> class Interpolator
auto idx = getIdx(ptg); auto idx = getIdx(ptg);
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched) execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{ {
vector<T> wt(supp), wp(supp); union {
native_simd<T> simd[64/vl];
T scalar[64];
} kdata;
T *wt(kdata.scalar), *wp(kdata.scalar+supp);
size_t nvec = (2*supp+vl-1)/vl;
for (auto &v: kdata.simd) v=0;
MR_assert(2*supp<=64, "support too large");
vector<T> psiarr(2*kmax+1); vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL #ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl); vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
...@@ -446,11 +457,13 @@ template<typename T> class Interpolator ...@@ -446,11 +457,13 @@ template<typename T> class Interpolator
T f0=T(0.5*supp+ptg(i,0)*xdtheta); T f0=T(0.5*supp+ptg(i,0)*xdtheta);
size_t i0 = size_t(f0+T(1)); size_t i0 = size_t(f0+T(1));
for (size_t t=0; t<supp; ++t) for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1); wt[t] = (t+i0-f0)*delta - 1;
T f1=T(0.5)*supp+ptg(i,1)*xdphi; T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.); size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t) for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1); wp[t] = (t+i1-f1)*delta - 1;
for (size_t t=0; t<nvec; ++t)
kdata.simd[t] = kernel(kdata.simd[t]);
psiarr[0]=1.; psiarr[0]=1.;
double psi=ptg(i,2); double psi=ptg(i,2);
double cpsi=cos(psi), spsi=sin(psi); double cpsi=cos(psi), spsi=sin(psi);
...@@ -484,25 +497,25 @@ template<typename T> class Interpolator ...@@ -484,25 +497,25 @@ template<typename T> class Interpolator
{ {
#ifdef SPECIAL_CASING #ifdef SPECIAL_CASING
case 1: case 1:
interpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<1,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 2: case 2:
interpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<2,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 3: case 3:
interpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<3,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 4: case 4:
interpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<4,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 5: case 5:
interpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<5,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 6: case 6:
interpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<6,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 7: case 7:
interpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<7,1>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
#endif #endif
default: default:
...@@ -549,25 +562,25 @@ template<typename T> class Interpolator ...@@ -549,25 +562,25 @@ template<typename T> class Interpolator
{ {
#ifdef SPECIAL_CASING #ifdef SPECIAL_CASING
case 1: case 1:
interpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<1,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 2: case 2:
interpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<2,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 3: case 3:
interpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<3,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 4: case 4:
interpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<4,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 5: case 5:
interpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<5,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 6: case 6:
interpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<6,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
case 7: case 7:
interpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i); interpol_help0<7,3>(wt, wp, p, d0, d1, psiarr2.data(), res, i);
break; break;
#endif #endif
default: default:
...@@ -629,7 +642,14 @@ template<typename T> class Interpolator ...@@ -629,7 +642,14 @@ template<typename T> class Interpolator
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched) execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{ {
size_t b_theta=99999999999999, b_phi=9999999999999999; size_t b_theta=99999999999999, b_phi=9999999999999999;
vector<T> wt(supp), wp(supp); union {
native_simd<T> simd[64/vl];
T scalar[64];
} kdata;
T *wt(kdata.scalar), *wp(kdata.scalar+supp);
size_t nvec = (2*supp+vl-1)/vl;
for (auto &v: kdata.simd) v=0;
MR_assert(2*supp<=64, "support too large");
vector<T> psiarr(2*kmax+1); vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL #ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl); vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
...@@ -641,11 +661,13 @@ template<typename T> class Interpolator ...@@ -641,11 +661,13 @@ template<typename T> class Interpolator
T f0=T(0.5)*supp+ptg(i,0)*xdtheta; T f0=T(0.5)*supp+ptg(i,0)*xdtheta;
size_t i0 = size_t(f0+1.); size_t i0 = size_t(f0+1.);
for (size_t t=0; t<supp; ++t) for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1); wt[t] = (t+i0-f0)*delta - 1;
T f1=T(0.5)*supp+ptg(i,1)*xdphi; T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.); size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t) for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1); wp[t] = (t+i1-f1)*delta - 1;
for (size_t t=0; t<nvec; ++t)
kdata.simd[t] = kernel(kdata.simd[t]);
psiarr[0]=1.; psiarr[0]=1.;
double psi=ptg(i,2); double psi=ptg(i,2);
double cpsi=cos(psi), spsi=sin(psi); double cpsi=cos(psi), spsi=sin(psi);
...@@ -696,25 +718,25 @@ template<typename T> class Interpolator ...@@ -696,25 +718,25 @@ template<typename T> class Interpolator
{ {
#ifdef SPECIAL_CASING #ifdef SPECIAL_CASING
case 1: case 1:
deinterpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<1,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 2: case 2:
deinterpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<2,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 3: case 3:
deinterpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<3,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 4: case 4:
deinterpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<4,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 5: case 5:
deinterpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<5,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 6: case 6:
deinterpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<6,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 7: case 7:
deinterpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<7,1>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
#endif #endif
default: default:
...@@ -759,25 +781,25 @@ template<typename T> class Interpolator ...@@ -759,25 +781,25 @@ template<typename T> class Interpolator
{ {
#ifdef SPECIAL_CASING #ifdef SPECIAL_CASING
case 1: case 1:
deinterpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<1,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 2: case 2:
deinterpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<2,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 3: case 3:
deinterpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<3,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 4: case 4:
deinterpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<4,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 5: case 5:
deinterpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<5,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 6: case 6:
deinterpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<6,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
case 7: case 7:
deinterpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i); deinterpol_help0<7,3>(wt, wp, p, d0, d1, psiarr2.data(), data, i);
break; break;
#endif #endif
default: default:
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include <array> #include <array>
#include "ducc0/infra/simd.h"
#include "ducc0/infra/error_handling.h" #include "ducc0/infra/error_handling.h"
#include "ducc0/infra/threading.h" #include "ducc0/infra/threading.h"
#include "ducc0/math/constants.h" #include "ducc0/math/constants.h"
...@@ -36,6 +37,28 @@ namespace detail_es_kernel { ...@@ -36,6 +37,28 @@ namespace detail_es_kernel {
using namespace std; using namespace std;
template<typename T> T esk (T v, T beta)
{
auto tmp = (1-v)*(1+v);
auto tmp2 = tmp>=0;
return tmp2*exp(T(beta)*(sqrt(tmp*tmp2)-1));
}
template<typename T> class exponator
{
public:
T operator()(T val) { return exp(val); }
};
template<typename T> native_simd<T> esk (native_simd<T> v, T beta)
{
auto tmp = (T(1)-v)*(T(1)+v);
auto mask = tmp<T(0);
where(mask,tmp)*=T(0);
auto res = (beta*(sqrt(tmp)-T(1))).apply(exponator<T>());
where (mask,res)*=T(0);
return res;
}
/* This class implements the "exponential of semicircle" gridding kernel /* This class implements the "exponential of semicircle" gridding kernel
described in https://arxiv.org/abs/1808.06736 */ described in https://arxiv.org/abs/1808.06736 */
class ES_Kernel class ES_Kernel
...@@ -62,11 +85,11 @@ class ES_Kernel ...@@ -62,11 +85,11 @@ class ES_Kernel
: ES_Kernel(supp_, 2., nthreads){} : ES_Kernel(supp_, 2., nthreads){}
template<typename T> T operator()(T v) const template<typename T> T operator()(T v) const
{ { return esk(v, T(beta)); }
auto tmp = (1-v)*(1+v);
auto tmp2 = tmp>=0; template<typename T> native_simd<T> operator()(native_simd<T> v) const
return tmp2*exp(T(beta)*(sqrt(tmp*tmp2)-1)); { return esk(v, T(beta)); }
}
/* Compute correction factors for the ES gridding kernel /* Compute correction factors for the ES gridding kernel
This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */ This implementation follows eqs. (3.8) to (3.10) of Barnett et al. 2018 */
double corfac(double v) const double corfac(double v) const
...@@ -138,6 +161,7 @@ class ES_Kernel ...@@ -138,6 +161,7 @@ class ES_Kernel
} }
using detail_es_kernel::esk;
using detail_es_kernel::ES_Kernel; using detail_es_kernel::ES_Kernel;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment