Commit 7a81a8d0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more tweaks

parent af6ffa43
......@@ -7,6 +7,7 @@
#define MRUTIL_INTERPOL_NG_H
#define SIMD_INTERPOL
#define SPECIAL_CASING
#include <vector>
#include <complex>
......@@ -354,11 +355,59 @@ template<typename T> class Interpolator
cube.apply([](T &v){v=0.;});
}
#ifdef SIMD_INTERPOL
template<size_t nv, size_t nc> void interpol_help0(const T * MRUTIL_RESTRICT wt,
const T * MRUTIL_RESTRICT wp, const native_simd<T> * MRUTIL_RESTRICT p, size_t d0, size_t d1, const native_simd<T> * MRUTIL_RESTRICT psiarr2, mav<T,2> &res, size_t idx) const
{
array<native_simd<T>,nc> vv;
for (auto &vvv:vv) vvv=0;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
array<native_simd<T>,nc> tvv;
for (auto &vvv:tvv) vvv=0;
for (size_t l=0; l<nv; ++l)
for (size_t c=0; c<nc; ++c)
tvv[c] += p1[l+nv*c]*psiarr2[l];
for (size_t c=0; c<nc; ++c)
vv[c] += (wtj*wp[k])*tvv[c];
}
}
for (size_t c=0; c<nc; ++c)
res.v(idx,c) = reduce(vv[c], std::plus<>());
}
template<size_t nv, size_t nc> void deinterpol_help0(const T * MRUTIL_RESTRICT wt,
const T * MRUTIL_RESTRICT wp, native_simd<T> * MRUTIL_RESTRICT p, size_t d0, size_t d1, const native_simd<T> * MRUTIL_RESTRICT psiarr2, const mav<T,2> &data, size_t idx) const
{
array<native_simd<T>,nc> vv;
for (size_t i=0; i<nc; ++i) vv[i] = data(idx,i);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l)
{
auto tmp = wtjwpk*psiarr2[l];
for (size_t c=0; c<nc; ++c)
p1[l+nv*c] += vv[c]*tmp;
}
}
}
}
#endif
void interpol (const mav<T,2> &ptg, mav<T,2> &res) const
{
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
size_t nv=scube.shape(3);
#endif
MR_assert(!adjoint, "cannot be called in adjoint mode");
MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
......@@ -416,22 +465,49 @@ template<typename T> class Interpolator
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
size_t nv=scube.shape(3);
native_simd<T> vv=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
switch (nv)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
{
auto p2=p1;
native_simd<T> tvv=0;
for (size_t l=0; l<nv; ++l,++p2)
tvv += *p2*psiarr2[l];
vv += wtj*wp[k]*tvv;
native_simd<T> vv=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> tvv=0;
for (size_t l=0; l<nv; ++l)
tvv += p1[l]*psiarr2[l];
vv += wtj*wp[k]*tvv;
}
}
res.v(i,0) = reduce(vv, std::plus<>());
}
}
res.v(i,0) = reduce(vv, std::plus<>());
#endif
}
else // ncomp==3
......@@ -454,29 +530,56 @@ template<typename T> class Interpolator
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
ptrdiff_t d2 = scube.stride(2);
size_t nv=scube.shape(3);
native_simd<T> v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
switch (nv)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
{
auto p2=p1;
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l,++p2)
ptrdiff_t d2 = scube.stride(2);
native_simd<T> v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto tmp = wtjwpk*psiarr2[l];
v0 += p2[0]*tmp;
v1 += p2[d2]*tmp;
v2 += p2[2*d2]*tmp;
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l)
{
auto tmp = wtjwpk*psiarr2[l];
v0 += p1[l]*tmp;
v1 += p1[l+d2]*tmp;
v2 += p1[l+2*d2]*tmp;
}
}
}
res.v(i,0) = reduce(v0, std::plus<>());
res.v(i,1) = reduce(v1, std::plus<>());
res.v(i,2) = reduce(v2, std::plus<>());
}
}
res.v(i,0) = reduce(v0, std::plus<>());
res.v(i,1) = reduce(v1, std::plus<>());
res.v(i,2) = reduce(v2, std::plus<>());
#endif
}
}
......@@ -491,6 +594,8 @@ template<typename T> class Interpolator
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
MR_assert(scube.stride(2)==ptrdiff_t(scube.shape(3)), "bad stride");
size_t nv=scube.shape(3);
#endif
MR_assert(adjoint, "can only be called in adjoint mode");
MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
......@@ -569,21 +674,48 @@ template<typename T> class Interpolator
for (size_t l=0; l<2*kmax+1; ++l)
cube.v(i0+j,i1+k,0,l) += val*wt[j]*wp[k]*psiarr[l];
#else
native_simd<T> val = data(i,0);
native_simd<T> *p=&scube.v(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
size_t nv=scube.shape(3);
for (size_t j=0; j<supp; ++j, p+=d0)
switch (nv)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
#ifdef SPECIAL_CASING
case 1:
deinterpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 2:
deinterpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 3:
deinterpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 4:
deinterpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 5:
deinterpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 6:
deinterpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 7:
deinterpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
#endif
default:
{
auto p2=p1;
native_simd<T> tv = wtj*wp[k]*val;
for (size_t l=0; l<nv; ++l,++p2)
*p2 += tv*psiarr2[l];
native_simd<T> val = data(i,0);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> tv = wtj*wp[k]*val;
for (size_t l=0; l<nv; ++l)
p1[l] += tv*psiarr2[l];
}
}
}
}
#endif
......@@ -605,26 +737,53 @@ template<typename T> class Interpolator
}
}
#else
native_simd<T> v0=data(i,0), v1=data(i,1), v2=data(i,2);
native_simd<T> *p=&scube.v(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
ptrdiff_t d2 = scube.stride(2);
size_t nv=scube.shape(3);
for (size_t j=0; j<supp; ++j, p+=d0)
switch (nv)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
#ifdef SPECIAL_CASING
case 1:
deinterpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 2:
deinterpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 3:
deinterpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 4:
deinterpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 5:
deinterpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 6:
deinterpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
case 7:
deinterpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), data, i);
break;
#endif
default:
{
auto p2=p1;
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l,++p2)
native_simd<T> v0=data(i,0), v1=data(i,1), v2=data(i,2);
ptrdiff_t d2 = scube.stride(2);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto tmp = wtjwpk*psiarr2[l];
p2[0] += v0*tmp;
p2[d2] += v1*tmp;
p2[2*d2] += v2*tmp;
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l)
{
auto tmp = wtjwpk*psiarr2[l];
p1[l] += v0*tmp;
p1[l+d2] += v1*tmp;
p1[l+2*d2] += v2*tmp;
}
}
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment