From d3d3204c55a2cc9ca37d382e1187ca139fff4d67 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Thu, 20 Jun 2019 13:04:26 +0200 Subject: [PATCH] add caching --- pocketfft_hdronly.h | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 584279d..bda7d5e 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols size_t axis) { if (nthreads==1) return 1; - if (prod(shape)/shape[axis] < 20) return 1; + if (prod(shape) < 20*shape[axis]) return 1; return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads; } #else @@ -2200,6 +2200,25 @@ template class pocketfft_r // multi-D infrastructure // +template shared_ptr get_plan(size_t length) + { + constexpr size_t nmax=10; + static vector> cache(nmax); + static size_t lastpos=0; + // searching phase + size_t pos = lastpos; + do { + if (!cache[pos]) break; + if (cache[pos]->length()==length) + return cache[pos]; + pos = (pos==0) ? nmax-1 : pos-1; + } while (pos != lastpos); + // adding phase + if (++lastpos>=nmax) lastpos=0; + cache[lastpos] = make_shared(length); + return cache[lastpos]; + } + class arr_info { protected: @@ -2455,14 +2474,14 @@ template NOINLINE void general_c( const cndarr> &in, ndarr> &out, const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr> plan; + shared_ptr> plan; for (size_t iax=0; iax::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_c(len)); + plan = get_plan>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) @@ -2522,14 +2541,14 @@ template NOINLINE void general_hartley( const cndarr &in, ndarr &out, const shape_t &axes, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr> plan; + shared_ptr> plan; for (size_t iax=0; iax::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_r(len)); + plan = get_plan>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) @@ -2590,7 +2609,7 @@ template NOINLINE void general_r2c( const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - pocketfft_r plan(in.shape(axis)); + auto plan = get_plan>(in.shape(axis)); constexpr auto vlen = VLEN::val; size_t len=in.shape(axis); #ifdef POCKETFFT_OPENMP @@ -2609,7 +2628,7 @@ template NOINLINE void general_r2c( for (size_t i=0; iforward(tdatav, fct); for (size_t j=0; j NOINLINE void general_r2c( auto tdata = reinterpret_cast(storage.data()); for (size_t i=0; iforward(tdata, fct); out[it.oofs(0)].Set(tdata[0]); size_t i=1, ii=1; if (forward) @@ -2650,7 +2669,7 @@ template NOINLINE void general_c2r( const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - pocketfft_r plan(out.shape(axis)); + auto plan = get_plan>(out.shape(axis)); constexpr auto vlen = VLEN::val; size_t len=out.shape(axis); #ifdef POCKETFFT_OPENMP @@ -2688,7 +2707,7 @@ template NOINLINE void general_c2r( for (size_t j=0; jbackward(tdatav, fct); for (size_t i=0; i NOINLINE void general_c2r( if (ibackward(tdata, fct); for (size_t i=0; i NOINLINE void general_r( const cndarr &in, ndarr &out, const shape_t &axes, bool r2c, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr> plan; + shared_ptr> plan; for (size_t iax=0; iax::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_r(len)); + plan = get_plan>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) -- GitLab