Commit d3d3204c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add caching

parent 5705b6a5
...@@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols ...@@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols
size_t axis) size_t axis)
{ {
if (nthreads==1) return 1; if (nthreads==1) return 1;
if (prod(shape)/shape[axis] < 20) return 1; if (prod(shape) < 20*shape[axis]) return 1;
return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads; return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads;
} }
#else #else
...@@ -2200,6 +2200,25 @@ template<typename T0> class pocketfft_r ...@@ -2200,6 +2200,25 @@ template<typename T0> class pocketfft_r
// multi-D infrastructure // multi-D infrastructure
// //
template<typename T> shared_ptr<T> get_plan(size_t length)
{
constexpr size_t nmax=10;
static vector<shared_ptr<T>> cache(nmax);
static size_t lastpos=0;
// searching phase
size_t pos = lastpos;
do {
if (!cache[pos]) break;
if (cache[pos]->length()==length)
return cache[pos];
pos = (pos==0) ? nmax-1 : pos-1;
} while (pos != lastpos);
// adding phase
if (++lastpos>=nmax) lastpos=0;
cache[lastpos] = make_shared<T>(length);
return cache[lastpos];
}
class arr_info class arr_info
{ {
protected: protected:
...@@ -2455,14 +2474,14 @@ template<typename T> NOINLINE void general_c( ...@@ -2455,14 +2474,14 @@ template<typename T> NOINLINE void general_c(
const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out, const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS) const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS)
{ {
unique_ptr<pocketfft_c<T>> plan; shared_ptr<pocketfft_c<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len)); plan = get_plan<pocketfft_c<T>>(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
...@@ -2522,14 +2541,14 @@ template<typename T> NOINLINE void general_hartley( ...@@ -2522,14 +2541,14 @@ template<typename T> NOINLINE void general_hartley(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct, const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
unique_ptr<pocketfft_r<T>> plan; shared_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len)); plan = get_plan<pocketfft_r<T>>(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
...@@ -2590,7 +2609,7 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2590,7 +2609,7 @@ template<typename T> NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct, const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(in.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis); size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
...@@ -2609,7 +2628,7 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2609,7 +2628,7 @@ template<typename T> NOINLINE void general_r2c(
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,i)]; tdatav[i][j] = in[it.iofs(j,i)];
plan.forward(tdatav, fct); plan->forward(tdatav, fct);
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]); out[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1; size_t i=1, ii=1;
...@@ -2632,7 +2651,7 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2632,7 +2651,7 @@ template<typename T> NOINLINE void general_r2c(
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
tdata[i] = in[it.iofs(i)]; tdata[i] = in[it.iofs(i)];
plan.forward(tdata, fct); plan->forward(tdata, fct);
out[it.oofs(0)].Set(tdata[0]); out[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1; size_t i=1, ii=1;
if (forward) if (forward)
...@@ -2650,7 +2669,7 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2650,7 +2669,7 @@ template<typename T> NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct, const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(out.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis); size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
...@@ -2688,7 +2707,7 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2688,7 +2707,7 @@ template<typename T> NOINLINE void general_c2r(
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,ii)].r; tdatav[i][j] = in[it.iofs(j,ii)].r;
} }
plan.backward(tdatav, fct); plan->backward(tdatav, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j]; out[it.oofs(j,i)] = tdatav[i][j];
...@@ -2716,7 +2735,7 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2716,7 +2735,7 @@ template<typename T> NOINLINE void general_c2r(
if (i<len) if (i<len)
tdata[i] = in[it.iofs(ii)].r; tdata[i] = in[it.iofs(ii)].r;
} }
plan.backward(tdata, fct); plan->backward(tdata, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i]; out[it.oofs(i)] = tdata[i];
} }
...@@ -2727,14 +2746,14 @@ template<typename T> NOINLINE void general_r( ...@@ -2727,14 +2746,14 @@ template<typename T> NOINLINE void general_r(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c, const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
bool forward, T fct, size_t POCKETFFT_NTHREADS) bool forward, T fct, size_t POCKETFFT_NTHREADS)
{ {
unique_ptr<pocketfft_r<T>> plan; shared_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len)); plan = get_plan<pocketfft_r<T>>(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment