diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 87c3186047c58530c39d56c6019847d7de407cf2..95e34a906f5ae49f93d5b429819fe3b2a132a8c8 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -266,50 +266,179 @@ template<bool fwd, typename T> void ROTX90(cmplx<T> &a) // twiddle factor section // +/** Approximate sin(pi*a) in single precision; only valid for a within the range [-0.25, 0.25]. The result is formed as pi*a plus a^3*P(a*a), where P is a degree-2 polynomial evaluated in Horner form with fma. */ +inline float sinpi0(float a) + { + // adapted from https://stackoverflow.com/questions/42792939/ + float s = a * a; + float r = -5.957031250000000000e-01f; + r = fma (r, s, 2.550399541854858398e+00f); + r = fma (r, s, -5.167724132537841797e+00f); + r = (a * s) * r; + return fma (a, 3.141592741012573242e+00f, r); + } + +/** Approximate sin(pi*a) in double precision; only valid for a within the range [-0.25, 0.25]. The result is formed as pi*a plus a^3*P(a*a), where P is a degree-5 polynomial evaluated in Horner form with fma. */ +inline double sinpi0(double a) + { + // adapted from https://stackoverflow.com/questions/42792939/ + double s = a * a; + double r = 4.6151442520157035e-4; + r = fma (r, s, -7.3700183130883555e-3); + r = fma (r, s, 8.2145868949323936e-2); + r = fma (r, s, -5.9926452893214921e-1); + r = fma (r, s, 2.5501640398732688e+0); + r = fma (r, s, -5.1677127800499516e+0); + s = s * a; + r = r * s; + return fma (a, 3.1415926535897931e+0, r); + } + +/** Approximate cos(pi*a)-1 in single precision; only valid for a in [-0.25,0.25]. The result is formed as (a*a)*Q(a*a), where Q is a degree-3 polynomial evaluated in Horner form with fmaf. */ +inline float cosm1pi0(float a) + { + // adapted from https://stackoverflow.com/questions/42792939/ + float s = a * a; + float r = 2.313842773437500000e-01f; + r = fmaf (r, s, -1.335021972656250000e+00f); + r = fmaf (r, s, 4.058703899383544922e+00f); + r = fmaf (r, s, -4.934802055358886719e+00f); + return r*s; + } + +/** Approximate cos(pi*a)-1 in double precision; only valid for a in [-0.25,0.25]. The result is formed as (a*a)*Q(a*a), where Q is a degree-6 polynomial evaluated in Horner form with fma. */ +inline double cosm1pi0(double a) + { + // adapted from https://stackoverflow.com/questions/42792939/ + double s = a * a; + double r = -1.0369917389758117e-4; + r = fma (r, s, 1.9294935641298806e-3); + r = fma (r, s, -2.5806887942825395e-2); + r = fma (r, s, 2.3533063028328211e-1); + r = fma (r, s, -1.3352627688538006e+0); + r = fma (r, s, 4.0587121264167623e+0); + r = fma (r, 
s, -4.9348022005446790e+0); + return r*s; + } + +/** Store cos(pi*a_)-1 in res[0] and sin(pi*a_) in res[1]. If T is wider than double, sin() is called directly and cos-1 is derived from that sine as (s*s) divided by (-sqrt((1-s)*(1+s))-1). Otherwise the argument is converted to double and the double overloads of cosm1pi0 and sinpi0 are used (also when T is float); those are documented as valid only for arguments in [-0.25, 0.25]. */ template <typename T> void sincosm1pi0(T a_, T *POCKETFFT_RESTRICT res) + { + if (sizeof(T)>sizeof(double)) // don't have the code for long double + { + constexpr T pi = T(3.141592653589793238462643383279502884197L); + auto s = sin(pi*a_); + res[1] = s; + res[0] = (s*s)/(-sqrt((1-s)*(1+s))-1); + return; + } + res[0] = T(cosm1pi0(double(a_))); + res[1] = T(sinpi0(double(a_))); + } + +/** Return sin(pi*a). With r = nearbyint(2*a) and t = a-r/2, the value of r modulo 4 (as a signed 64-bit integer) selects whether sinpi0(t) or cosm1pi0(t)+1 is returned, and with which sign. Throws runtime_error if no switch case matches, which the message marks as unreachable. NOTE(review): r is cast to int64_t without a range check, so this presumably assumes that 2*a fits into 64 bits; confirm for very large arguments. */ template <typename T> T sinpi(T a) + { + // reduce argument to primary approximation interval (-0.25, 0.25) + auto r = nearbyint (a + a); // must use IEEE-754 "to nearest" rounding + auto i = (int64_t)r; + auto t = fma (T(-0.5), r, a); + + switch (i%4) + { + case 0: + return sinpi0(t); + case 1: case -3: + return cosm1pi0(t) + T(1.); + case 2: case -2: + return T(0.) - sinpi0(t); + case 3: case -1: + return T(-1.) - cosm1pi0(t); + } + throw runtime_error("cannot happen"); + } + +/** Return cos(pi*a). With r = nearbyint(2*a) and t = a-r/2, the value of r modulo 4 (as a signed 64-bit integer) selects whether cosm1pi0(t)+1 or sinpi0(t) is returned, and with which sign. Throws runtime_error if no switch case matches, which the message marks as unreachable. NOTE(review): r is cast to int64_t without a range check, so this presumably assumes that 2*a fits into 64 bits; confirm for very large arguments. */ template <typename T> T cospi(T a) + { + // reduce argument to primary approximation interval (-0.25, 0.25) + auto r = nearbyint (a + a); // must use IEEE-754 "to nearest" rounding + auto i = (int64_t)r; + auto t = fma (T(-0.5), r, a); + + switch (i%4) + { + case 0: + return cosm1pi0(t) + T(1.); + case 1: case -3: + return T(0.) - sinpi0(t); + case 2: case -2: + return T(-1.) - cosm1pi0(t); + case 3: case -1: + return sinpi0(t); + } + throw runtime_error("cannot happen"); + } + +/** Return cos(pi*a) for a long double argument. When long double is wider than double, cos(a*pi) is evaluated directly; otherwise a is converted to double, cospi is called on that value and the result is converted back. */ inline long double cospi(long double a) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + return sizeof(long double) > sizeof(double) ? cos(a * pi) : + static_cast<long double>(cospi(static_cast<double>(a))); + } + +/** Return sin(pi*a) for a long double argument. When long double is wider than double, sin(a*pi) is evaluated directly; otherwise a is converted to double, sinpi is called on that value and the result is converted back. */ inline long double sinpi(long double a) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + return sizeof(long double) > sizeof(double) ? 
sin(a * pi) : + static_cast<long double>(sinpi(static_cast<double>(a))); + } + +/** Store cos(pi*a) in res[0] and sin(pi*a) in res[1]. With r = nearbyint(2*a) and t = a-r/2, cosm1pi0(t)+1 and sinpi0(t) are each evaluated once; bit 1 of the integer quadrant index negates both values, and bit 0 swaps them and then negates the cosine. NOTE(review): r is cast to int64_t without a range check, so this presumably assumes that 2*a fits into 64 bits; confirm for very large arguments. */ template <typename T> void sincospi(T a, T *POCKETFFT_RESTRICT res) + { + // reduce argument to primary approximation interval (-0.25, 0.25) + auto r = nearbyint (a + a); // must use IEEE-754 "to nearest" rounding + auto i = (int64_t)r; + auto t = fma (T(-0.5), r, a); + + auto c=cosm1pi0(t)+T(1.); + auto s=sinpi0(t); + + // map results according to quadrant + if (i & 2) + { + s = T(0.)-s; + c = T(0.)-c; + } + if (i & 1) + { + swap(s, c); + c = T(0.)-c; + } + res[0]=c; + res[1]=s; + } + +/** Store cos(pi*a) in res[0] and sin(pi*a) in res[1] for a long double argument. When long double is wider than double, cos and sin of pi*a are evaluated directly; otherwise sincospi is called on a converted to double and both results are converted back. */ inline void sincospi(long double a, long double *POCKETFFT_RESTRICT res) + { + if (sizeof(long double) > sizeof(double)) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + res[0] = cos(pi * a); + res[1] = sin(pi * a); + } + else + { + double sincos[2]; + sincospi(static_cast<double>(a), sincos); + res[0] = static_cast<long double>(sincos[0]); + res[1] = static_cast<long double>(sincos[1]); + } + } + template<typename T> class sincos_2pibyn { private: using Thigh = typename conditional<(sizeof(T)>sizeof(double)), T, double>::type; arr<T> data; - void my_sincosm1pi (Thigh a_, Thigh *POCKETFFT_RESTRICT res) - { - if (sizeof(Thigh)>sizeof(double)) // don't have the code for long double - { - constexpr Thigh pi = Thigh(3.141592653589793238462643383279502884197L); - auto s = sin(pi*a_); - res[1] = s; - res[0] = (s*s)/(-sqrt((1-s)*(1+s))-1); - return; - } - // adapted from https://stackoverflow.com/questions/42792939/ - // CAUTION: this function only works for arguments in the range - // [-0.25; 0.25]! 
- double a = double(a_); - double s = a * a; - /* Approximate cos(pi*x)-1 for x in [-0.25,0.25] */ - double r = -1.0369917389758117e-4; - r = fma (r, s, 1.9294935641298806e-3); - r = fma (r, s, -2.5806887942825395e-2); - r = fma (r, s, 2.3533063028328211e-1); - r = fma (r, s, -1.3352627688538006e+0); - r = fma (r, s, 4.0587121264167623e+0); - r = fma (r, s, -4.9348022005446790e+0); - double c = r*s; - /* Approximate sin(pi*x) for x in [-0.25,0.25] */ - r = 4.6151442520157035e-4; - r = fma (r, s, -7.3700183130883555e-3); - r = fma (r, s, 8.2145868949323936e-2); - r = fma (r, s, -5.9926452893214921e-1); - r = fma (r, s, 2.5501640398732688e+0); - r = fma (r, s, -5.1677127800499516e+0); - s = s * a; - r = r * s; - s = fma (a, 3.1415926535897931e+0, r); - res[0] = c; - res[1] = s; - } - POCKETFFT_NOINLINE void calc_first_octant(size_t den, T * POCKETFFT_RESTRICT res) { @@ -321,7 +450,7 @@ template<typename T> class sincos_2pibyn arr<Thigh> tmp(2*l1); for (size_t i=1; i<l1; ++i) { - my_sincosm1pi(Thigh(2*i)/Thigh(den),&tmp[2*i]); + sincosm1pi0(Thigh(2*i)/Thigh(den),&tmp[2*i]); res[2*i ] = T(tmp[2*i]+1); res[2*i+1] = T(tmp[2*i+1]); } @@ -329,7 +458,7 @@ template<typename T> class sincos_2pibyn while(start<n) { Thigh cs[2]; - my_sincosm1pi((Thigh(2*start))/Thigh(den),cs); + sincosm1pi0((Thigh(2*start))/Thigh(den),cs); res[2*start] = T(cs[0]+1); res[2*start+1] = T(cs[1]); size_t end = l1; @@ -2561,9 +2690,9 @@ template<typename T0> class T_dcst23 POCKETFFT_NOINLINE T_dcst23(size_t length) : fftplan(length), twiddle(length) { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + const auto oo2n = T0(0.5)/T0(length); /* 1/(2*length), so that twiddle[i] below is cos(pi*(i+1)/(2*length)) */ for (size_t i=0; i<length; ++i) - twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); + twiddle[i] = cospi(oo2n*T0(i+1)); } template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, @@ -2638,12 +2767,13 @@ template<typename T0> class T_dcst4 rfft((N&1)? new pocketfft_r<T0>(N) : nullptr), C2((N&1) ? 
0 : N/2) { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + const auto oon = -T0(1.)/T0(N); /* -1/N, so that C2[i] below holds cosine and sine of -pi*(i+0.125)/N */ if ((N&1)==0) for (size_t i=0; i<N/2; ++i) { - T0 ang = -pi/T0(N)*(T0(i)+T0(0.125)); - C2[i].Set(cos(ang), sin(ang)); + T0 sincos[2]; + sincospi(oon*(T0(i)+T0(0.125)), sincos); + C2[i].Set(sincos[0], sincos[1]); } } diff --git a/stress.py b/stress.py index 6b97490b63489ed1450311adb481e8426817485a..ce1a43397d41647b26b2da95642bb3efb9980308 100644 --- a/stress.py +++ b/stress.py @@ -50,6 +50,7 @@ def test(err): axes = axes[:nax] lastsize = shape[axes[-1]] a = np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j + a_32 = a.astype(np.complex64) b = ifftn(fftn(a, axes=axes, nthreads=nthreads), axes=axes, inorm=2, nthreads=nthreads) err = update_err(err, "cmax", _l2error(a, b)) @@ -86,6 +87,51 @@ def test(err): nthreads=nthreads), axes=axes, inorm=2, nthreads=nthreads) err = update_err(err, "hmaxf", _l2error(a.real.astype(np.float32), b)) + if all(a.shape[i] > 1 for i in axes): + b = pypocketfft.dct( + pypocketfft.dct(a.real, axes=axes, nthreads=nthreads, type=1), + axes=axes, type=1, nthreads=nthreads, inorm=2) + err = update_err(err, "c1max", _l2error(a.real, b)) + b = pypocketfft.dct( + pypocketfft.dct(a_32.real, axes=axes, nthreads=nthreads, type=1), + axes=axes, type=1, nthreads=nthreads, inorm=2) + err = update_err(err, "c1maxf", _l2error(a_32.real, b)) + b = pypocketfft.dct( + pypocketfft.dct(a.real, axes=axes, nthreads=nthreads, type=2), + axes=axes, type=3, nthreads=nthreads, inorm=2) + err = update_err(err, "c23max", _l2error(a.real, b)) + b = pypocketfft.dct( + pypocketfft.dct(a_32.real, axes=axes, nthreads=nthreads, type=2), + axes=axes, type=3, nthreads=nthreads, inorm=2) + err = update_err(err, "c23maxf", _l2error(a_32.real, b)) + b = pypocketfft.dct( + pypocketfft.dct(a.real, axes=axes, nthreads=nthreads, type=4), + axes=axes, type=4, nthreads=nthreads, inorm=2) + err = update_err(err, "c4max", _l2error(a.real, b)) + b = pypocketfft.dct( + 
pypocketfft.dct(a_32.real, axes=axes, nthreads=nthreads, type=4), + axes=axes, type=4, nthreads=nthreads, inorm=2) + err = update_err(err, "c4maxf", _l2error(a_32.real, b)) + b = pypocketfft.dst( + pypocketfft.dst(a_32.real, axes=axes, nthreads=nthreads, type=1), + axes=axes, type=1, nthreads=nthreads, inorm=2) + err = update_err(err, "s1maxf", _l2error(a_32.real, b)) + b = pypocketfft.dst( + pypocketfft.dst(a.real, axes=axes, nthreads=nthreads, type=2), + axes=axes, type=3, nthreads=nthreads, inorm=2) + err = update_err(err, "s23max", _l2error(a.real, b)) + b = pypocketfft.dst( + pypocketfft.dst(a_32.real, axes=axes, nthreads=nthreads, type=2), + axes=axes, type=3, nthreads=nthreads, inorm=2) + err = update_err(err, "s23maxf", _l2error(a_32.real, b)) + b = pypocketfft.dst( + pypocketfft.dst(a.real, axes=axes, nthreads=nthreads, type=4), + axes=axes, type=4, nthreads=nthreads, inorm=2) + err = update_err(err, "s4max", _l2error(a.real, b)) + b = pypocketfft.dst( + pypocketfft.dst(a_32.real, axes=axes, nthreads=nthreads, type=4), + axes=axes, type=4, nthreads=nthreads, inorm=2) + err = update_err(err, "s4maxf", _l2error(a_32.real, b)) err = dict()