diff --git a/README.md b/README.md index 9948ff1caf2132f1eb80c536da9dd76e8ee39d92..a54e6bb2b8fa5f564841a67c86bddd7f35bb6d46 100644 --- a/README.md +++ b/README.md @@ -17,4 +17,4 @@ Features - makes use of CPU vector instructions when performing 2D and higher-dimensional transforms - supports prime-length transforms without degrading to O(N**2) performance -- has optional OpenMP support for multidimensional transforms +- has optional multi-threading support for multidimensional transforms diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 95e34a906f5ae49f93d5b429819fe3b2a132a8c8..a90bd400fead85a91ae120da52ec4e1e4e3b7f7f 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -63,6 +63,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include <mutex> #endif +#ifndef POCKETFFT_NO_MULTITHREADING #include <mutex> #include <condition_variable> #include <thread> @@ -73,7 +74,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifdef POCKETFFT_PTHREADS # include <pthread.h> #endif - +#endif #if defined(__GNUC__) #define POCKETFFT_NOINLINE __attribute__((noinline)) @@ -694,21 +695,39 @@ struct util // hack to avoid duplicate symbols if (axis>=shape.size()) throw invalid_argument("bad axis number"); } - static size_t thread_count (size_t nthreads, const shape_t &shape, - size_t axis, size_t vlen) - { - if (nthreads==1) return 1; - size_t size = prod(shape); - size_t parallel = size / (shape[axis] * vlen); - if (shape[axis] < 1000) - parallel /= 4; - size_t max_threads = nthreads == 0 ? - thread::hardware_concurrency() : nthreads; - return max(size_t(1), min(parallel, max_threads)); - } +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? + thread::hardware_concurrency() : nthreads; + return max(size_t(1), min(parallel, max_threads)); + } +#endif }; namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr size_t thread_id = 0; +constexpr size_t num_threads = 1; + +template <typename Func> +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + thread_local size_t thread_id = 0; thread_local size_t num_threads = 1; @@ -892,6 +911,9 @@ void thread_map(size_t nthreads, Func f) if (ex) rethrow_exception(ex); } + +#endif + } // @@ -2752,8 +2774,6 @@ template<typename T0> class T_dcst23 template<typename T0> class T_dcst4 { - // even length algorithm from - // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ private: size_t N; unique_ptr<pocketfft_c<T0>> fft; @@ -2789,9 +2809,8 @@ template<typename T0> class T_dcst4 { // The following code is derived from the FFTW3 function apply_re11() // and is released under the 3-clause BSD license with friendly - // permission of Matteo Frigo. + // permission of Matteo Frigo and Steven G. Johnson. - auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;}; arr<T> y(N); { size_t i=0, m=n2; @@ -2803,30 +2822,27 @@ template<typename T0> class T_dcst4 y[i] = -c[m-2*N]; for (; m<4*N; ++i, m+=4) y[i] = c[4*N-m-1]; - m -= 4*N; for (; i<N; ++i, m+=4) - y[i] = c[m]; + y[i] = c[m-4*N]; } rfft->exec(y.data(), fct, true); { - size_t i=0; - for (; i+i+1<n2; ++i) + auto SGN = [sqrt2](size_t i) {return (i&2) ? -sqrt2 : sqrt2;}; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k<n2; ++i, ++i1, k+=2) { - size_t k = i+i+1; - T c1=y[2*k-1], s1=y[2*k], c2=y[2*k+1], s2=y[2*k+2]; - c[i] = sqrt2 * (SGN_SET(c1, (i+1)/2) + SGN_SET(s1, i/2)); - c[N-(i+1)] = sqrt2 * (SGN_SET(c1, (N-i)/2) - SGN_SET(s1, (N-(i+1))/2)); - c[n2-(i+1)] = sqrt2 * (SGN_SET(c2, (n2-i)/2) - SGN_SET(s2, (n2-(i+1))/2)); - c[n2+(i+1)] = sqrt2 * (SGN_SET(c2, (n2+i+2)/2) + SGN_SET(s2, (n2+(i+1))/2)); + c[i ] = y[2*k-1]*SGN(i1) + y[2*k ]*SGN(i); + c[N -i1] = y[2*k-1]*SGN(N -i) - y[2*k ]*SGN(N -i1); + c[n2-i1] = y[2*k+1]*SGN(n2-i) - y[2*k+2]*SGN(n2-i1); + c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1); } - if (i+i+1 == n2) + if (k == n2) { - T cx=y[2*n2-1], sx=y[2*n2]; - c[i] = sqrt2 * (SGN_SET(cx, (i+1)/2) + SGN_SET(sx, i/2)); - c[N-(i+1)] = sqrt2 * (SGN_SET(cx, (i+2)/2) + SGN_SET(sx, (i+1)/2)); + c[i ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i); + c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1); } } - c[n2] = sqrt2 * SGN_SET(y[0], (n2+1)/2); // FFTW-derived code ends here } diff --git a/setup.py b/setup.py index cd8c7979cc5cae44356f5824b88451c6c0a95457..985e84607c2dea8348718b4909ef050ec6ca2a0e 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_extension_modules(): return [Extension('pypocketfft', language='c++', sources=['pypocketfft.cc'], - depends=['pocketfft_hdronly.h'], + depends=['pocketfft_hdronly.h', 'setup.py'], include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args,