Commit 4ea60b74 authored by Martin Reinecke

add switch to disable multithreading; fix lambda capture for MSVC

parent 4e511317
...@@ -17,4 +17,4 @@ Features ...@@ -17,4 +17,4 @@ Features
- makes use of CPU vector instructions when performing 2D and higher-dimensional - makes use of CPU vector instructions when performing 2D and higher-dimensional
transforms transforms
- supports prime-length transforms without degrading to O(N**2) performance - supports prime-length transforms without degrading to O(N**2) performance
- has optional OpenMP support for multidimensional transforms - has optional multi-threading support for multidimensional transforms
...@@ -63,6 +63,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -63,6 +63,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <mutex> #include <mutex>
#endif #endif
#ifndef POCKETFFT_NO_MULTITHREADING
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <thread> #include <thread>
...@@ -73,7 +74,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -73,7 +74,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef POCKETFFT_PTHREADS #ifdef POCKETFFT_PTHREADS
# include <pthread.h> # include <pthread.h>
#endif #endif
#endif
#if defined(__GNUC__) #if defined(__GNUC__)
#define POCKETFFT_NOINLINE __attribute__((noinline)) #define POCKETFFT_NOINLINE __attribute__((noinline))
...@@ -694,21 +695,39 @@ struct util // hack to avoid duplicate symbols ...@@ -694,21 +695,39 @@ struct util // hack to avoid duplicate symbols
if (axis>=shape.size()) throw invalid_argument("bad axis number"); if (axis>=shape.size()) throw invalid_argument("bad axis number");
} }
static size_t thread_count (size_t nthreads, const shape_t &shape, #ifdef POCKETFFT_NO_MULTITHREADING
size_t axis, size_t vlen) static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/,
{ size_t /*axis*/, size_t /*vlen*/)
if (nthreads==1) return 1; { return 1; }
size_t size = prod(shape); #else
size_t parallel = size / (shape[axis] * vlen); static size_t thread_count (size_t nthreads, const shape_t &shape,
if (shape[axis] < 1000) size_t axis, size_t vlen)
parallel /= 4; {
size_t max_threads = nthreads == 0 ? if (nthreads==1) return 1;
thread::hardware_concurrency() : nthreads; size_t size = prod(shape);
return max(size_t(1), min(parallel, max_threads)); size_t parallel = size / (shape[axis] * vlen);
} if (shape[axis] < 1000)
parallel /= 4;
size_t max_threads = nthreads == 0 ?
thread::hardware_concurrency() : nthreads;
return max(size_t(1), min(parallel, max_threads));
}
#endif
}; };
namespace threading { namespace threading {
#ifdef POCKETFFT_NO_MULTITHREADING
constexpr size_t thread_id = 0;
constexpr size_t num_threads = 1;
// Fallback selected when POCKETFFT_NO_MULTITHREADING is defined: the requested
// thread count is ignored and the callable is run exactly once, synchronously,
// on the calling thread.
template <class Callable>
void thread_map(size_t /* nthreads */, Callable f)
  {
  f();
  }
#else
thread_local size_t thread_id = 0; thread_local size_t thread_id = 0;
thread_local size_t num_threads = 1; thread_local size_t num_threads = 1;
...@@ -892,6 +911,9 @@ void thread_map(size_t nthreads, Func f) ...@@ -892,6 +911,9 @@ void thread_map(size_t nthreads, Func f)
if (ex) if (ex)
rethrow_exception(ex); rethrow_exception(ex);
} }
#endif
} }
// //
...@@ -2789,7 +2811,6 @@ template<typename T0> class T_dcst4 ...@@ -2789,7 +2811,6 @@ template<typename T0> class T_dcst4
// and is released under the 3-clause BSD license with friendly // and is released under the 3-clause BSD license with friendly
// permission of Matteo Frigo and Steven G. Johnson. // permission of Matteo Frigo and Steven G. Johnson.
auto SGN = [](size_t i) {return (i&2) ? -sqrt2 : sqrt2;};
arr<T> y(N); arr<T> y(N);
{ {
size_t i=0, m=n2; size_t i=0, m=n2;
...@@ -2806,6 +2827,7 @@ template<typename T0> class T_dcst4 ...@@ -2806,6 +2827,7 @@ template<typename T0> class T_dcst4
} }
rfft->exec(y.data(), fct, true); rfft->exec(y.data(), fct, true);
{ {
auto SGN = [sqrt2](size_t i) {return (i&2) ? -sqrt2 : sqrt2;};
c[n2] = y[0]*SGN(n2+1); c[n2] = y[0]*SGN(n2+1);
size_t i=0, i1=1, k=1; size_t i=0, i1=1, k=1;
for (; k<n2; ++i, ++i1, k+=2) for (; k<n2; ++i, ++i1, k+=2)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment