From d3d3204c55a2cc9ca37d382e1187ca139fff4d67 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 20 Jun 2019 13:04:26 +0200 Subject: [PATCH] add caching --- pocketfft_hdronly.h | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 584279d..bda7d5e 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols size_t axis) { if (nthreads==1) return 1; - if (prod(shape)/shape[axis] < 20) return 1; + if (prod(shape) < 20*shape[axis]) return 1; return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads; } #else @@ -2200,6 +2200,25 @@ template<typename T0> class pocketfft_r // multi-D infrastructure // +template<typename T> shared_ptr<T> get_plan(size_t length) + { + constexpr size_t nmax=10; + static vector<shared_ptr<T>> cache(nmax); + static size_t lastpos=0; + // searching phase + size_t pos = lastpos; + do { + if (!cache[pos]) break; + if (cache[pos]->length()==length) + return cache[pos]; + pos = (pos==0) ? nmax-1 : pos-1; + } while (pos != lastpos); + // adding phase + if (++lastpos>=nmax) lastpos=0; + cache[lastpos] = make_shared<T>(length); + return cache[lastpos]; + } + class arr_info { protected: @@ -2455,14 +2474,14 @@ template<typename T> NOINLINE void general_c( const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out, const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr<pocketfft_c<T>> plan; + shared_ptr<pocketfft_c<T>> plan; for (size_t iax=0; iax<axes.size(); ++iax) { constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_c<T>(len)); + plan = get_plan<pocketfft_c<T>>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) @@ -2522,14 +2541,14 @@ template<typename T> NOINLINE void general_hartley( const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr<pocketfft_r<T>> plan; + shared_ptr<pocketfft_r<T>> plan; for (size_t iax=0; iax<axes.size(); ++iax) { constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_r<T>(len)); + plan = get_plan<pocketfft_r<T>>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) @@ -2590,7 +2609,7 @@ template<typename T> NOINLINE void general_r2c( const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - pocketfft_r<T> plan(in.shape(axis)); + auto plan = get_plan<pocketfft_r<T>>(in.shape(axis)); constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axis); #ifdef POCKETFFT_OPENMP @@ -2609,7 +2628,7 @@ template<typename T> NOINLINE void general_r2c( for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) tdatav[i][j] = in[it.iofs(j,i)]; - plan.forward(tdatav, fct); + plan->forward(tdatav, fct); for (size_t j=0; j<vlen; ++j) out[it.oofs(j,0)].Set(tdatav[0][j]); size_t i=1, ii=1; @@ -2632,7 +2651,7 @@ template<typename T> NOINLINE void general_r2c( auto tdata = reinterpret_cast<T *>(storage.data()); for (size_t i=0; i<len; ++i) tdata[i] = in[it.iofs(i)]; - plan.forward(tdata, fct); + plan->forward(tdata, fct); out[it.oofs(0)].Set(tdata[0]); size_t i=1, ii=1; if (forward) @@ -2650,7 +2669,7 @@ template<typename T> NOINLINE void general_c2r( const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - pocketfft_r<T> plan(out.shape(axis)); + auto plan = get_plan<pocketfft_r<T>>(out.shape(axis)); constexpr auto vlen = VLEN<T>::val; size_t len=out.shape(axis); #ifdef POCKETFFT_OPENMP @@ -2688,7 +2707,7 @@ template<typename T> NOINLINE void general_c2r( for (size_t j=0; j<vlen; ++j) tdatav[i][j] = in[it.iofs(j,ii)].r; } - plan.backward(tdatav, fct); + plan->backward(tdatav, fct); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) out[it.oofs(j,i)] = tdatav[i][j]; @@ -2716,7 +2735,7 @@ template<typename T> NOINLINE void general_c2r( if (i<len) tdata[i] = in[it.iofs(ii)].r; } - plan.backward(tdata, fct); + plan->backward(tdata, fct); for (size_t i=0; i<len; ++i) out[it.oofs(i)] = tdata[i]; } @@ -2727,14 +2746,14 @@ template<typename T> NOINLINE void general_r( const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - unique_ptr<pocketfft_r<T>> plan; + shared_ptr<pocketfft_r<T>> plan; for (size_t iax=0; iax<axes.size(); ++iax) { constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) - plan.reset(new pocketfft_r<T>(len)); + plan = get_plan<pocketfft_r<T>>(len); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) -- GitLab