Commit cfd1df1e authored by Martin Reinecke
Browse files

some rearrangements and speedups

parent 97858e67
......@@ -46,6 +46,21 @@ namespace detail_gridder {
using namespace std;
// Generic horizontal sum: folds all lanes of `vr` with std::plus to form the
// real part and all lanes of `vi` to form the imaginary part of the result.
template<typename T> complex<T> hsum_cmplx(native_simd<T> vr, native_simd<T> vi)
  {
  const auto re_sum = reduce(vr, std::plus<>());
  const auto im_sum = reduce(vi, std::plus<>());
  return complex<T>(re_sum, im_sum);
  }
#if (defined(__AVX__) && (!defined(__AVX512F__)))
// Single-precision overload of hsum_cmplx, compiled only when AVX is available
// and AVX-512F is not. Returns (sum of all lanes of vr, sum of all lanes of vi)
// using horizontal-add intrinsics instead of two separate reductions.
// NOTE(review): presumably native_simd<float> wraps a __m256 (8 lanes) under
// this guard and converts implicitly to it; the intrinsics below rely on
// that - confirm against the simd header.
inline complex<float> hsum_cmplx(native_simd<float> vr, native_simd<float> vi)
  {
  // 256-bit hadd operates on each 128-bit half independently:
  // t1 = [r0+r1, r2+r3, i0+i1, i2+i3 | r4+r5, r6+r7, i4+i5, i6+i7]
  auto t1 = _mm256_hadd_ps(vr, vi);
  // t2 = [r0+..+r3, i0+..+i3, r4+..+r7, i4+..+i7]
  auto t2 = _mm_hadd_ps(_mm256_extractf128_ps(t1, 0), _mm256_extractf128_ps(t1, 1));
  // Swap the two 64-bit halves and add, so that lane 0 holds the complete
  // real sum and lane 1 the complete imaginary sum.
  t2 += _mm_shuffle_ps(t2, t2, _MM_SHUFFLE(1,0,3,2));
  return complex<float>(t2[0], t2[1]);
  }
#endif
// Checks that two ndim-dimensional shapes are element-wise identical.
// If they differ, the MR_assert check fails with the message "shape mismatch".
// NOTE(review): how MR_assert reports a failure (exception, abort, ...) is
// defined elsewhere - confirm before relying on it.
template<size_t ndim> void checkShape
(const array<size_t, ndim> &shp1, const array<size_t, ndim> &shp2)
{ MR_assert(shp1==shp2, "shape mismatch"); }
......@@ -760,15 +775,15 @@ void apply_global_corrections(const GridderConfig<T> &gconf, mav<T,2> &dirty)
timers.push("global corrections");
double x0 = -0.5*nxdirty*pixsize_x,
y0 = -0.5*nydirty*pixsize_y;
auto cfu = gconf.krn->corfunc(nx_dirty/2+1, 1./gconf.Nu(), nthreads);
auto cfv = gconf.krn->corfunc(ny_dirty/2+1, 1./gconf.Nv(), nthreads);
execStatic(nx_dirty/2+1, nthreads, 0, [&](Scheduler &sched)
auto cfu = gconf.krn->corfunc(nxdirty/2+1, 1./gconf.Nu(), nthreads);
auto cfv = gconf.krn->corfunc(nydirty/2+1, 1./gconf.Nv(), nthreads);
execStatic(nxdirty/2+1, nthreads, 0, [&](Scheduler &sched)
{
while (auto rng=sched.getNext()) for(auto i=rng.lo; i<rng.hi; ++i)
{
auto fx = T(x0+i*pixsize_x);
fx *= fx;
for (size_t j=0; j<=ny_dirty/2; ++j)
for (size_t j=0; j<=nydirty/2; ++j)
{
auto fy = T(y0+j*pixsize_y);
fy*=fy;
......@@ -791,8 +806,8 @@ void apply_global_corrections(const GridderConfig<T> &gconf, mav<T,2> &dirty)
fct = T(gconf.krn->corfunc(nm1*dw));
}
}
fct *= T(cfu[nx_dirty/2-i]*cfv[ny_dirty/2-j]);
size_t i2 = nx_dirty-i, j2 = ny_dirty-j;
fct *= T(cfu[nxdirty/2-i]*cfv[nydirty/2-j]);
size_t i2 = nxdirty-i, j2 = nydirty-j;
dirty.v(i,j)*=fct;
if ((i>0)&&(i<i2))
{
......@@ -1079,15 +1094,28 @@ template<size_t SUPP, bool wgrid, typename T> [[gnu::hot]] void x2grid_c_helper
native_simd<T> vr(v.real()), vi(v.imag());
for (size_t cu=0; cu<SUPP; ++cu)
{
native_simd<T> tmpr=vr*ku[cu], tmpi=vi*ku[cu];
for (size_t cv=0; cv<NVEC; ++cv)
if constexpr (NVEC==1)
{
auto fct = kv[0]*ku[cu];
auto tr = native_simd<T>::loadu(ptrr);
tr += vr*fct;
tr.storeu(ptrr);
auto ti = native_simd<T>::loadu(ptri);
ti += vi*fct;
ti.storeu(ptri);
}
else
{
auto tr = native_simd<T>::loadu(ptrr+cv*hlp.vlen);
tr += tmpr*kv[cv];
tr.storeu(ptrr+cv*hlp.vlen);
auto ti = native_simd<T>::loadu(ptri+cv*hlp.vlen);
ti += tmpi*kv[cv];
ti.storeu(ptri+cv*hlp.vlen);
native_simd<T> tmpr=vr*ku[cu], tmpi=vi*ku[cu];
for (size_t cv=0; cv<NVEC; ++cv)
{
auto tr = native_simd<T>::loadu(ptrr+cv*hlp.vlen);
tr += tmpr*kv[cv];
tr.storeu(ptrr+cv*hlp.vlen);
auto ti = native_simd<T>::loadu(ptri+cv*hlp.vlen);
ti += tmpi*kv[cv];
ti.storeu(ptri+cv*hlp.vlen);
}
}
ptrr+=jump;
ptri+=jump;
......@@ -1169,18 +1197,28 @@ template<size_t SUPP, bool wgrid, typename T> [[gnu::hot]] void grid2x_c_helper
const auto * DUCC0_RESTRICT ptri = hlp.p0i;
for (size_t cu=0; cu<SUPP; ++cu)
{
native_simd<T> tmpr(0), tmpi(0);
for (size_t cv=0; cv<NVEC; ++cv)
// if constexpr(NVEC==1)
// {
// auto fct = kv[0]*ku[cu];
// rr += native_simd<T>::loadu(ptrr)*fct;
// ri += native_simd<T>::loadu(ptri)*fct;
// }
// else
{
tmpr += kv[cv]*native_simd<T>::loadu(ptrr+hlp.vlen*cv);
tmpi += kv[cv]*native_simd<T>::loadu(ptri+hlp.vlen*cv);
native_simd<T> tmpr(0), tmpi(0);
for (size_t cv=0; cv<NVEC; ++cv)
{
tmpr += kv[cv]*native_simd<T>::loadu(ptrr+hlp.vlen*cv);
tmpi += kv[cv]*native_simd<T>::loadu(ptri+hlp.vlen*cv);
}
rr += ku[cu]*tmpr;
ri += ku[cu]*tmpi;
}
rr += ku[cu]*tmpr;
ri += ku[cu]*tmpi;
ptrr += jump;
ptri += jump;
}
auto r = complex<T>(reduce(rr, std::plus<>()), reduce(ri, std::plus<>()));
auto r = hsum_cmplx(rr,ri);
// auto r = complex<T>(reduce(rr, std::plus<>()), reduce(ri, std::plus<>()));
if (flip) r=conj(r);
if (have_wgt) r*=wgt(row, ch);
vis.v(row, ch) += r;
......@@ -1228,65 +1266,6 @@ template<bool wgrid, typename T> void grid2x_c
par.timers.pop();
}
template<typename T> void apply_global_corrections(const GridderConfig<T> &gconf,
Params<T> &par, mav<T,2> &dirty)
{
par.timers.push("global corrections");
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
size_t nthreads = gconf.Nthreads();
auto psx=gconf.Pixsize_x();
auto psy=gconf.Pixsize_y();
double x0 = -0.5*nx_dirty*psx,
y0 = -0.5*ny_dirty*psy;
auto cfu = gconf.krn->corfunc(nx_dirty/2+1, 1./gconf.Nu(), nthreads);
auto cfv = gconf.krn->corfunc(ny_dirty/2+1, 1./gconf.Nv(), nthreads);
execStatic(nx_dirty/2+1, nthreads, 0, [&](Scheduler &sched)
{
while (auto rng=sched.getNext()) for(auto i=rng.lo; i<rng.hi; ++i)
{
auto fx = T(x0+i*psx);
fx *= fx;
for (size_t j=0; j<=ny_dirty/2; ++j)
{
auto fy = T(y0+j*psy);
fy*=fy;
T fct = 0;
auto tmp = 1-fx-fy;
if (tmp>=0)
{
auto nm1 = (-fx-fy)/(sqrt(tmp)+1); // accurate form of sqrt(1-x-y)-1
fct = T(gconf.krn->corfunc(nm1*par.dw));
if (par.divide_by_n)
fct /= nm1+1;
}
else // beyond the horizon, don't really know what to do here
{
if (par.divide_by_n)
fct=0;
else
{
auto nm1 = sqrt(-tmp)-1;
fct = T(gconf.krn->corfunc(nm1*par.dw));
}
}
fct *= T(cfu[nx_dirty/2-i]*cfv[ny_dirty/2-j]);
size_t i2 = nx_dirty-i, j2 = ny_dirty-j;
dirty.v(i,j)*=fct;
if ((i>0)&&(i<i2))
{
dirty.v(i2,j)*=fct;
if ((j>0)&&(j<j2))
dirty.v(i2,j2)*=fct;
}
if ((j>0)&&(j<j2))
dirty.v(i,j2)*=fct;
}
}
});
par.timers.pop();
}
template<typename T> void x2dirty(
GridderConfig<T> &gconf, const mav<complex<T>,2> &vis, const mav<T,2> &wgt, mav<T,2> &dirty,
Params<T> &par)
......@@ -1308,7 +1287,7 @@ template<typename T> void x2dirty(
gconf.grid2dirty_c_overwrite_wscreen_add(grid, dirty, T(w));
}
// correct for w gridding etc.
apply_global_corrections(gconf, par, dirty);
par.apply_global_corrections(gconf, dirty);
}
else
{
......@@ -1337,7 +1316,7 @@ template<typename T> void dirty2x(
tdirty.apply(dirty, [](T&a, T b) {a=b;});
par.timers.pop();
// correct for w gridding etc.
apply_global_corrections(gconf, par, tdirty);
par.apply_global_corrections(gconf, tdirty);
par.timers.push("allocating grid");
auto grid = mav<complex<T>,2>::build_noncritical({gconf.Nu(),gconf.Nv()});
par.timers.pop();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment