Commit 589b7f50 authored by Peter Bell's avatar Peter Bell

WIP: Pre-planned transforms

parent c0f74f61
...@@ -2876,57 +2876,89 @@ template<typename T0> class T_dcst4 ...@@ -2876,57 +2876,89 @@ template<typename T0> class T_dcst4
// multi-D infrastructure // multi-D infrastructure
// //
template<typename T> shared_ptr<T> get_plan(size_t length) template <typename T> class planner
{ {
#if POCKETFFT_CACHE_SIZE==0 public:
return make_shared<T>(length); virtual shared_ptr<T> get_plan(size_t len) = 0;
#else virtual ~planner() = default;
constexpr size_t nmax=POCKETFFT_CACHE_SIZE; };
static array<shared_ptr<T>, nmax> cache;
static array<size_t, nmax> last_access{{0}};
static size_t access_counter = 0;
static mutex mut;
auto find_in_cache = [&]() -> shared_ptr<T> template <typename T> class plan_cache: public planner<T>
{ {
for (size_t i=0; i<nmax; ++i) vector<shared_ptr<T>> cache_;
if (cache[i] && (cache[i]->length()==length)) vector<size_t> last_access_;
{ size_t access_counter_ = 0;
// no need to update if this is already the most recent entry mutex mut_;
if (last_access[i]!=access_counter)
shared_ptr<T> find_in_cache(size_t length)
{
for (size_t i=0; i<cache_.size(); ++i)
if (cache_[i] && (cache_[i]->length()==length))
{ {
last_access[i] = ++access_counter; // no need to update if this is already the most recent entry
// Guard against overflow if (last_access_[i]!=access_counter_)
if (access_counter == 0) {
last_access.fill(0); last_access_[i] = ++access_counter_;
// Guard against overflow
if (access_counter_ == 0)
std::fill(last_access_.begin(), last_access_.end(), 0);
}
return cache_[i];
} }
return cache[i];
return {};
};
virtual shared_ptr<T> get_plan(size_t length)
{
{
unique_lock<mutex> lock(mut_);
if (cache_.size() == 0)
{
lock.unlock();
return make_shared<T>(length);
} }
auto p = find_in_cache(length);
if (p) return p;
}
auto plan = make_shared<T>(length);
{
lock_guard<mutex> lock(mut_);
auto p = find_in_cache(length);
if (p) return p;
return nullptr; size_t lru = 0;
}; for (size_t i=1; i<cache_.size(); ++i)
if (last_access_[i] < last_access_[lru])
lru = i;
{ cache_[lru] = plan;
lock_guard<mutex> lock(mut); last_access_[lru] = ++access_counter_;
auto p = find_in_cache(); }
if (p) return p; return plan;
} }
auto plan = make_shared<T>(length);
{
lock_guard<mutex> lock(mut);
auto p = find_in_cache();
if (p) return p;
size_t lru = 0; public:
for (size_t i=1; i<nmax; ++i)
if (last_access[i] < last_access[lru])
lru = i;
cache[lru] = plan; plan_cache(size_t cache_size):
last_access[lru] = ++access_counter; cache_(cache_size),
} last_access_(cache_size)
return plan; {
#endif }
void resize(size_t new_size)
{
lock_guard<mutex> lock(mut_);
cache_.resize(new_size);
last_access_.resize(new_size);
// TODO: When shrinking, remove lru elements first
}
};
template <typename T> plan_cache<T> & get_plan_cache()
{
static plan_cache<T> cache(POCKETFFT_CACHE_SIZE);
return cache;
} }
class arr_info class arr_info
...@@ -3235,15 +3267,17 @@ template <typename T> using add_vec_t = typename add_vec<T>::type; ...@@ -3235,15 +3267,17 @@ template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec> template<typename Tplan, typename T, typename T0, typename Exec>
POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out, POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true) planner<Tplan> *planner=nullptr, const bool allow_inplace=true)
{ {
shared_ptr<Tplan> plan; shared_ptr<Tplan> plan;
if (!planner)
planner = &get_plan_cache<Tplan>();
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan = get_plan<Tplan>(len); plan = planner->get_plan(len);
threading::thread_map( threading::thread_map(
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val), util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
...@@ -3348,9 +3382,11 @@ struct ExecDcst ...@@ -3348,9 +3382,11 @@ struct ExecDcst
template<typename T> POCKETFFT_NOINLINE void general_r2c( template<typename T> POCKETFFT_NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct, const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads, planner<pocketfft_r<T>> *planner=nullptr)
{ {
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis)); if (!planner)
planner = &get_plan_cache<pocketfft_r<T>>();
auto plan = planner->get_plan(in.shape(axis));
size_t len=in.shape(axis); size_t len=in.shape(axis);
threading::thread_map( threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
...@@ -3403,9 +3439,11 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c( ...@@ -3403,9 +3439,11 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
} }
template<typename T> POCKETFFT_NOINLINE void general_c2r( template<typename T> POCKETFFT_NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct, const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads, planner<pocketfft_r<T>> *planner=nullptr)
{ {
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis)); if (!planner)
planner = &get_plan_cache<pocketfft_r<T>>();
auto plan = planner->get_plan(out.shape(axis));
size_t len=out.shape(axis); size_t len=out.shape(axis);
threading::thread_map( threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
...@@ -3481,16 +3519,56 @@ struct ExecR2R ...@@ -3481,16 +3519,56 @@ struct ExecR2R
} }
}; };
template <typename T> class preplanned_transform: public planner<T>
{
vector<shared_ptr<T>> plans;
public:
void add_plan(shared_ptr<T> plan)
{
plans.push_back(move(plan));
}
virtual shared_ptr<T> get_plan(size_t length)
{
auto it = std::find(plans.begin(), plans.end(), length);
if (it == plans.end())
throw invalid_argument("Bad pre-planned transform, axis length did not match");
return *it;
}
};
template <typename T> using c2c_plan_t = preplanned_transform<pocketfft_c<T>>;
template <typename T> c2c_plan_t<T> plan_c2c(const shape_t &shape, const shape_t &axes)
{
std::vector<size_t> lengths;
for (auto a : axes)
{
if (a >= shape.size())
throw invalid_argument("Axes exceeds dimensionality of input");
lengths.push_back(shape[a]);
}
std::sort(lengths.begin(), lengths.end());
auto last = std::unique(lengths.begin(), lengths.end());
c2c_plan_t<T> plan;
for (auto l : lengths)
plan.add_plan(make_shared<pocketfft_c<T>>(l));
return move(plan);
}
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in, template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward, const stride_t &stride_out, const shape_t &axes, bool forward,
const complex<T> *data_in, complex<T> *data_out, T fct, const complex<T> *data_in, complex<T> *data_out, T fct,
size_t nthreads=1) size_t nthreads=1, c2c_plan_t<T> *plan=nullptr)
{ {
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<cmplx<T>> ain(data_in, shape, stride_in); cndarr<cmplx<T>> ain(data_in, shape, stride_in);
ndarr<cmplx<T>> aout(data_out, shape, stride_out); ndarr<cmplx<T>> aout(data_out, shape, stride_out);
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward},
plan);
} }
template<typename T> void dct(const shape_t &shape, template<typename T> void dct(const shape_t &shape,
...@@ -3623,7 +3701,7 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape, ...@@ -3623,7 +3701,7 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); ndarr<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{}, general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
false); nullptr, false);
} }
} // namespace detail } // namespace detail
...@@ -3632,6 +3710,7 @@ using detail::FORWARD; ...@@ -3632,6 +3710,7 @@ using detail::FORWARD;
using detail::BACKWARD; using detail::BACKWARD;
using detail::shape_t; using detail::shape_t;
using detail::stride_t; using detail::stride_t;
using detail::plan_c2c;
using detail::c2c; using detail::c2c;
using detail::c2r; using detail::c2r;
using detail::r2c; using detail::r2c;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment