diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 006eebfb2ca705aca6fea160d2fa650bc9280a25..474a7a8325643eeaead13d3e91d621c4489aef6f 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -79,15 +79,30 @@ using stride_t = vector; constexpr bool FORWARD = true, BACKWARD = false; +// only enable vector support for gcc>=5.0 and clang>=5.0; clang's version must be read from __clang_major__ because __clang__ itself is always defined as 1 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__ +#elif defined(__clang__) +#if __clang_major__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + #ifndef POCKETFFT_NO_VECTORS #if (defined(__AVX512F__)) -#define HAVE_VECSUPPORT constexpr int VBYTELEN=64; #elif (defined(__AVX__)) -#define HAVE_VECSUPPORT constexpr int VBYTELEN=32; #elif (defined(__SSE2__)) -#define HAVE_VECSUPPORT +constexpr int VBYTELEN=16; +#elif (defined(__VSX__)) constexpr int VBYTELEN=16; #else #define POCKETFFT_NO_VECTORS @@ -191,8 +206,9 @@ template struct cmplx { template auto special_mul (const cmplx &other) const -> cmplx { - return bwd ? cmplx(r*other.r-i*other.i, r*other.i + i*other.r) - : cmplx(r*other.r+i*other.i, i*other.r - r*other.i); + using Tres = cmplx; + return bwd ? Tres(r*other.r-i*other.i, r*other.i + i*other.r) + : Tres(r*other.r+i*other.i, i*other.r - r*other.i); } }; template void PMC(cmplx &a, cmplx &b, @@ -203,8 +219,8 @@ template cmplx conj(const cmplx &a) template void ROT90(cmplx &a) { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } -template void ROTM90(cmplx &a) - { auto tmp_=-a.r; a.r=a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= bwd ? a.r : -a.r; a.r = bwd ? 
-a.i : a.i; a.i=tmp_; } // // twiddle factor section @@ -223,19 +239,19 @@ template class sincos_2pibyn using Thigh = typename TypeSelectorsizeof(double))>::type; arr data; - // adapted from https://stackoverflow.com/questions/42792939/ - // CAUTION: this function only works for arguments in the range - // [-0.25; 0.25]! void my_sincosm1pi (Thigh a_, Thigh *restrict res) { if (sizeof(Thigh)>sizeof(double)) // don't have the code for long double { - Thigh pi = Thigh(3.141592653589793238462643383279502884197L); - res[1] = sin(pi*a_); - auto s = res[1]; + constexpr Thigh pi = Thigh(3.141592653589793238462643383279502884197L); + auto s = sin(pi*a_); + res[1] = s; res[0] = (s*s)/(-sqrt((1-s)*(1+s))-1); return; } + // adapted from https://stackoverflow.com/questions/42792939/ + // CAUTION: this function only works for arguments in the range + // [-0.25; 0.25]! double a = double(a_); double s = a * a; /* Approximate cos(pi*x)-1 for x in [-0.25,0.25] */ @@ -612,7 +628,7 @@ template void pass4 (size_t ido, size_t l1, T t1, t2, t3, t4; PMC(t2,t1,CC(0,0,k),CC(0,2,k)); PMC(t3,t4,CC(0,1,k),CC(0,3,k)); - bwd ? ROT90(t4) : ROTM90(t4); + ROTX90(t4); PMC(CH(0,k,0),CH(0,k,2),t2,t3); PMC(CH(0,k,1),CH(0,k,3),t1,t4); } @@ -623,7 +639,7 @@ template void pass4 (size_t ido, size_t l1, T t1, t2, t3, t4; PMC(t2,t1,CC(0,0,k),CC(0,2,k)); PMC(t3,t4,CC(0,1,k),CC(0,3,k)); - bwd ? ROT90(t4) : ROTM90(t4); + ROTX90(t4); PMC(CH(0,k,0),CH(0,k,2),t2,t3); PMC(CH(0,k,1),CH(0,k,3),t1,t4); } @@ -633,13 +649,12 @@ template void pass4 (size_t ido, size_t l1, T cc0=CC(i,0,k), cc1=CC(i,1,k),cc2=CC(i,2,k),cc3=CC(i,3,k); PMC(t2,t1,cc0,cc2); PMC(t3,t4,cc1,cc3); - bwd ? 
ROT90(t4) : ROTM90(t4); - cmplx wa0=WA(0,i), wa1=WA(1,i),wa2=WA(2,i); + ROTX90(t4); PMC(CH(i,k,0),c3,t2,t3); PMC(c2,c4,t1,t4); - CH(i,k,1) = c2.template special_mul(wa0); - CH(i,k,2) = c3.template special_mul(wa1); - CH(i,k,3) = c4.template special_mul(wa2); + CH(i,k,1) = c2.template special_mul(WA(0,i)); + CH(i,k,2) = c3.template special_mul(WA(1,i)); + CH(i,k,3) = c4.template special_mul(WA(2,i)); } } } @@ -779,6 +794,107 @@ template void pass7(size_t ido, size_t l1, #undef PARTSTEP7a #undef PREP7 +template void ROTX45(T &a) + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (bwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + } +template void ROTX135(T &a) + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (bwd) + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a.r+=b.r; a.i+=b.i; b.r=t.r-b.r; b.i=t.i-b.i; } + +template void pass8 (size_t ido, size_t l1, + const T * restrict cc, T * restrict ch, const cmplx * restrict wa) + { + constexpr size_t cdim=8; + + if (ido==1) + for (size_t k=0; k(a6); + ROTX90(a7); + PMINPLACE(a0,a2); + PMINPLACE(a1,a3); + PMINPLACE(a4,a6); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX90(a3); + ROTX135(a7); + PMC(CH(0,k,0),CH(0,k,4),a0,a1); + PMC(CH(0,k,1),CH(0,k,5),a4,a5); + PMC(CH(0,k,2),CH(0,k,6),a2,a3); + PMC(CH(0,k,3),CH(0,k,7),a6,a7); + } + else + for (size_t k=0; k(a6); + ROTX90(a7); + PMINPLACE(a0,a2); + PMINPLACE(a1,a3); + PMINPLACE(a4,a6); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX90(a3); + ROTX135(a7); + PMC(CH(0,k,0),CH(0,k,4),a0,a1); + PMC(CH(0,k,1),CH(0,k,5),a4,a5); + PMC(CH(0,k,2),CH(0,k,6),a2,a3); + PMC(CH(0,k,3),CH(0,k,7),a6,a7); + + for (size_t i=1; i(a6); + ROTX90(a7); + PMINPLACE(a0,a2); + PMINPLACE(a1,a3); + PMINPLACE(a4,a6); + PMINPLACE(a5,a7); + 
ROTX45(a5); + ROTX90(a3); + ROTX135(a7); + PMINPLACE(a0,a1); + PMINPLACE(a2,a3); + PMINPLACE(a4,a5); + PMINPLACE(a6,a7); + CH(i,k,0) = a0; + CH(i,k,1) = a4.template special_mul(WA(0,i)); + CH(i,k,2) = a2.template special_mul(WA(1,i)); + CH(i,k,3) = a6.template special_mul(WA(2,i)); + CH(i,k,4) = a1.template special_mul(WA(3,i)); + CH(i,k,5) = a5.template special_mul(WA(4,i)); + CH(i,k,6) = a3.template special_mul(WA(5,i)); + CH(i,k,7) = a7.template special_mul(WA(6,i)); + } + } + } + + #define PREP11(idx) \ T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ PMC (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ @@ -979,6 +1095,8 @@ template void pass_all(T c[], T0 fct) size_t ido = length/l2; if (ip==4) pass4 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); else if(ip==2) pass2(ido, l1, p1, p2, fact[k1].tw); else if(ip==3) @@ -1026,6 +1144,8 @@ template void pass_all(T c[], T0 fct) NOINLINE void factorize() { size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } while ((len&3)==0) { add_factor(4); len>>=2; } if ((len&1)==0) @@ -1874,12 +1994,12 @@ template void radbg(size_t ido, size_t ip, size_t l1, fact[k].tws=ptr; ptr+=2*ip; fact[k].tws[0] = 1.; fact[k].tws[1] = 0.; - for (size_t i=1; i<=(ip>>1); ++i) + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) { - fact[k].tws[2*i ] = twid[2*i*(length/ip)]; - fact[k].tws[2*i+1] = twid[2*i*(length/ip)+1]; - fact[k].tws[2*(ip-i) ] = twid[2*i*(length/ip)]; - fact[k].tws[2*(ip-i)+1] = -twid[2*i*(length/ip)+1]; + fact[k].tws[i ] = twid[i*(length/ip)]; + fact[k].tws[i+1] = twid[i*(length/ip)+1]; + fact[k].tws[ic] = twid[i*(length/ip)]; + fact[k].tws[ic+1] = -twid[i*(length/ip)+1]; } } l1*=ip; @@ -2192,7 +2312,7 @@ template class multi_iter bool contiguous_out() const { return str_o==sizeof(To); } }; -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS template struct VTYPE {}; template<> struct VTYPE { @@ -2261,7 +2381,7 @@ template NOINLINE void general_c( { auto 
storage = alloc_tmp(in.shape(), len, sizeof(cmplx)); multi_iter, cmplx> it(iax==0? in : out, out, axes[iax]); -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { @@ -2324,7 +2444,7 @@ template NOINLINE void general_hartley( { auto storage = alloc_tmp(in.shape(), len, sizeof(T)); multi_iter it(iax==0 ? in : out, out, axes[iax]); -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { @@ -2382,7 +2502,7 @@ template NOINLINE void general_r2c( { auto storage = alloc_tmp(in.shape(), len, sizeof(T)); multi_iter> it(in, out, axis); -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { @@ -2433,7 +2553,7 @@ template NOINLINE void general_c2r( { auto storage = alloc_tmp(out.shape(), len, sizeof(T)); multi_iter, T> it(in, out, axis); -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { @@ -2489,7 +2609,7 @@ template NOINLINE void general_r( { auto storage = alloc_tmp(in.shape(), len, sizeof(T)); multi_iter it(in, out, axis); -#if defined(HAVE_VECSUPPORT) +#ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) { @@ -2532,7 +2652,6 @@ template NOINLINE void general_r( } #undef POCKETFFT_NTHREADS -#undef HAVE_VECSUPPORT template void c2c(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward,