From d3d3204c55a2cc9ca37d382e1187ca139fff4d67 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Thu, 20 Jun 2019 13:04:26 +0200
Subject: [PATCH] add caching

---
 pocketfft_hdronly.h | 45 ++++++++++++++++++++++++++++++++-------------
 1 file changed, 32 insertions(+), 13 deletions(-)

diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h
index 584279d..bda7d5e 100644
--- a/pocketfft_hdronly.h
+++ b/pocketfft_hdronly.h
@@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols
       size_t axis)
       {
       if (nthreads==1) return 1;
-      if (prod(shape)/shape[axis] < 20) return 1;
+      if (prod(shape) < 20*shape[axis]) return 1;
       return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads;
       }
 #else
@@ -2200,6 +2200,25 @@ template<typename T0> class pocketfft_r
 // multi-D infrastructure
 //
 
+template<typename T> shared_ptr<T> get_plan(size_t length) // returns a shared plan of type T for the given length, reusing a cached one when available
+  { // NOTE(review): cache/lastpos are mutated with no locking visible here -- confirm callers never invoke this concurrently
+  constexpr size_t nmax=10; // number of slots in the cache
+  static vector<shared_ptr<T>> cache(nmax); // function-local static: one cache per plan type T; all slots start out null
+  static size_t lastpos=0; // index of the slot written most recently (0 before the first insertion)
+  // searching phase: walk backwards (with wrap-around) from the newest slot, looking for a plan with the requested length
+  size_t pos = lastpos;
+  do {
+    if (!cache[pos]) break; // empty slot reached: end the search
+    if (cache[pos]->length()==length)
+      return cache[pos]; // cache hit: hand out another reference to the existing plan
+    pos = (pos==0) ? nmax-1 : pos-1; // step to the previous slot, wrapping from 0 to nmax-1
+    } while (pos != lastpos); // back at the starting slot: every slot has been inspected
+  // adding phase: nothing matched, so advance to the next slot (wrapping at nmax) and overwrite its content with a new plan
+  if (++lastpos>=nmax) lastpos=0;
+  cache[lastpos] = make_shared<T>(length); // a previously stored plan stays alive for as long as callers still hold their shared_ptr
+  return cache[lastpos];
+  }
+
 class arr_info
   {
   protected:
@@ -2455,14 +2474,14 @@ template<typename T> NOINLINE void general_c(
   const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
   const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS)
   {
-  unique_ptr<pocketfft_c<T>> plan;
+  shared_ptr<pocketfft_c<T>> plan;
 
   for (size_t iax=0; iax<axes.size(); ++iax)
     {
     constexpr auto vlen = VLEN<T>::val;
     size_t len=in.shape(axes[iax]);
     if ((!plan) || (len!=plan->length()))
-      plan.reset(new pocketfft_c<T>(len));
+      plan = get_plan<pocketfft_c<T>>(len);
 
 #ifdef POCKETFFT_OPENMP
 #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
@@ -2522,14 +2541,14 @@ template<typename T> NOINLINE void general_hartley(
   const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct,
   size_t POCKETFFT_NTHREADS)
   {
-  unique_ptr<pocketfft_r<T>> plan;
+  shared_ptr<pocketfft_r<T>> plan;
 
   for (size_t iax=0; iax<axes.size(); ++iax)
     {
     constexpr auto vlen = VLEN<T>::val;
     size_t len=in.shape(axes[iax]);
     if ((!plan) || (len!=plan->length()))
-      plan.reset(new pocketfft_r<T>(len));
+      plan = get_plan<pocketfft_r<T>>(len);
 
 #ifdef POCKETFFT_OPENMP
 #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
@@ -2590,7 +2609,7 @@ template<typename T> NOINLINE void general_r2c(
   const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
   size_t POCKETFFT_NTHREADS)
   {
-  pocketfft_r<T> plan(in.shape(axis));
+  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
   constexpr auto vlen = VLEN<T>::val;
   size_t len=in.shape(axis);
 #ifdef POCKETFFT_OPENMP
@@ -2609,7 +2628,7 @@ template<typename T> NOINLINE void general_r2c(
       for (size_t i=0; i<len; ++i)
         for (size_t j=0; j<vlen; ++j)
           tdatav[i][j] = in[it.iofs(j,i)];
-      plan.forward(tdatav, fct);
+      plan->forward(tdatav, fct);
       for (size_t j=0; j<vlen; ++j)
         out[it.oofs(j,0)].Set(tdatav[0][j]);
       size_t i=1, ii=1;
@@ -2632,7 +2651,7 @@ template<typename T> NOINLINE void general_r2c(
     auto tdata = reinterpret_cast<T *>(storage.data());
     for (size_t i=0; i<len; ++i)
       tdata[i] = in[it.iofs(i)];
-    plan.forward(tdata, fct);
+    plan->forward(tdata, fct);
     out[it.oofs(0)].Set(tdata[0]);
     size_t i=1, ii=1;
     if (forward)
@@ -2650,7 +2669,7 @@ template<typename T> NOINLINE void general_c2r(
   const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
   size_t POCKETFFT_NTHREADS)
   {
-  pocketfft_r<T> plan(out.shape(axis));
+  auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
   constexpr auto vlen = VLEN<T>::val;
   size_t len=out.shape(axis);
 #ifdef POCKETFFT_OPENMP
@@ -2688,7 +2707,7 @@ template<typename T> NOINLINE void general_c2r(
         for (size_t j=0; j<vlen; ++j)
           tdatav[i][j] = in[it.iofs(j,ii)].r;
       }
-      plan.backward(tdatav, fct);
+      plan->backward(tdatav, fct);
       for (size_t i=0; i<len; ++i)
         for (size_t j=0; j<vlen; ++j)
           out[it.oofs(j,i)] = tdatav[i][j];
@@ -2716,7 +2735,7 @@ template<typename T> NOINLINE void general_c2r(
     if (i<len)
       tdata[i] = in[it.iofs(ii)].r;
     }
-    plan.backward(tdata, fct);
+    plan->backward(tdata, fct);
     for (size_t i=0; i<len; ++i)
       out[it.oofs(i)] = tdata[i];
     }
@@ -2727,14 +2746,14 @@ template<typename T> NOINLINE void general_r(
   const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
   bool forward, T fct, size_t POCKETFFT_NTHREADS)
   {
-  unique_ptr<pocketfft_r<T>> plan;
+  shared_ptr<pocketfft_r<T>> plan;
 
   for (size_t iax=0; iax<axes.size(); ++iax)
     {
     constexpr auto vlen = VLEN<T>::val;
     size_t len=in.shape(axes[iax]);
     if ((!plan) || (len!=plan->length()))
-      plan.reset(new pocketfft_r<T>(len));
+      plan = get_plan<pocketfft_r<T>>(len);
 
 #ifdef POCKETFFT_OPENMP
 #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
-- 
GitLab