Commit d3d3204c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add caching

parent 5705b6a5
......@@ -503,7 +503,7 @@ struct util // hack to avoid duplicate symbols
size_t axis)
{
if (nthreads==1) return 1;
if (prod(shape)/shape[axis] < 20) return 1;
if (prod(shape) < 20*shape[axis]) return 1;
return (nthreads==0) ? size_t(omp_get_max_threads()) : nthreads;
}
#else
......@@ -2200,6 +2200,25 @@ template<typename T0> class pocketfft_r
// multi-D infrastructure
//
template<typename T> shared_ptr<T> get_plan(size_t length)
{
constexpr size_t nmax=10;
static vector<shared_ptr<T>> cache(nmax);
static size_t lastpos=0;
// searching phase
size_t pos = lastpos;
do {
if (!cache[pos]) break;
if (cache[pos]->length()==length)
return cache[pos];
pos = (pos==0) ? nmax-1 : pos-1;
} while (pos != lastpos);
// adding phase
if (++lastpos>=nmax) lastpos=0;
cache[lastpos] = make_shared<T>(length);
return cache[lastpos];
}
class arr_info
{
protected:
......@@ -2455,14 +2474,14 @@ template<typename T> NOINLINE void general_c(
const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS)
{
unique_ptr<pocketfft_c<T>> plan;
shared_ptr<pocketfft_c<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len));
plan = get_plan<pocketfft_c<T>>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
......@@ -2522,14 +2541,14 @@ template<typename T> NOINLINE void general_hartley(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct,
size_t POCKETFFT_NTHREADS)
{
unique_ptr<pocketfft_r<T>> plan;
shared_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
plan = get_plan<pocketfft_r<T>>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
......@@ -2590,7 +2609,7 @@ template<typename T> NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(in.shape(axis));
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP
......@@ -2609,7 +2628,7 @@ template<typename T> NOINLINE void general_r2c(
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,i)];
plan.forward(tdatav, fct);
plan->forward(tdatav, fct);
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1;
......@@ -2632,7 +2651,7 @@ template<typename T> NOINLINE void general_r2c(
auto tdata = reinterpret_cast<T *>(storage.data());
for (size_t i=0; i<len; ++i)
tdata[i] = in[it.iofs(i)];
plan.forward(tdata, fct);
plan->forward(tdata, fct);
out[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1;
if (forward)
......@@ -2650,7 +2669,7 @@ template<typename T> NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(out.shape(axis));
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP
......@@ -2688,7 +2707,7 @@ template<typename T> NOINLINE void general_c2r(
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,ii)].r;
}
plan.backward(tdatav, fct);
plan->backward(tdatav, fct);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
......@@ -2716,7 +2735,7 @@ template<typename T> NOINLINE void general_c2r(
if (i<len)
tdata[i] = in[it.iofs(ii)].r;
}
plan.backward(tdata, fct);
plan->backward(tdata, fct);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
......@@ -2727,14 +2746,14 @@ template<typename T> NOINLINE void general_r(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
bool forward, T fct, size_t POCKETFFT_NTHREADS)
{
unique_ptr<pocketfft_r<T>> plan;
shared_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
plan = get_plan<pocketfft_r<T>>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment