Commit c63461d6 authored by Peter Bell's avatar Peter Bell

WIP: Pre-planned transforms

parent c0f74f61
......@@ -2876,57 +2876,89 @@ template<typename T0> class T_dcst4
// multi-D infrastructure
//
template<typename T> shared_ptr<T> get_plan(size_t length)
template <typename T> class planner
{
#if POCKETFFT_CACHE_SIZE==0
return make_shared<T>(length);
#else
constexpr size_t nmax=POCKETFFT_CACHE_SIZE;
static array<shared_ptr<T>, nmax> cache;
static array<size_t, nmax> last_access{{0}};
static size_t access_counter = 0;
static mutex mut;
public:
virtual shared_ptr<T> get_plan(size_t len) = 0;
virtual ~planner() = default;
};
auto find_in_cache = [&]() -> shared_ptr<T>
{
for (size_t i=0; i<nmax; ++i)
if (cache[i] && (cache[i]->length()==length))
{
// no need to update if this is already the most recent entry
if (last_access[i]!=access_counter)
template <typename T> class plan_cache: public planner<T>
{
vector<shared_ptr<T>> cache_;
vector<size_t> last_access_;
size_t access_counter_ = 0;
mutex mut_;
shared_ptr<T> find_in_cache(size_t length)
{
for (size_t i=0; i<cache_.size(); ++i)
if (cache_[i] && (cache_[i]->length()==length))
{
last_access[i] = ++access_counter;
// Guard against overflow
if (access_counter == 0)
last_access.fill(0);
// no need to update if this is already the most recent entry
if (last_access_[i]!=access_counter_)
{
last_access_[i] = ++access_counter_;
// Guard against overflow
if (access_counter_ == 0)
std::fill(last_access_.begin(), last_access_.end(), 0);
}
return cache_[i];
}
return cache[i];
return {};
};
virtual shared_ptr<T> get_plan(size_t length)
{
{
unique_lock<mutex> lock(mut_);
if (cache_.size() == 0)
{
lock.unlock();
return make_shared<T>(length);
}
auto p = find_in_cache(length);
if (p) return p;
}
auto plan = make_shared<T>(length);
{
lock_guard<mutex> lock(mut_);
auto p = find_in_cache(length);
if (p) return p;
return nullptr;
};
size_t lru = 0;
for (size_t i=1; i<cache_.size(); ++i)
if (last_access_[i] < last_access_[lru])
lru = i;
{
lock_guard<mutex> lock(mut);
auto p = find_in_cache();
if (p) return p;
}
auto plan = make_shared<T>(length);
{
lock_guard<mutex> lock(mut);
auto p = find_in_cache();
if (p) return p;
cache_[lru] = plan;
last_access_[lru] = ++access_counter_;
}
return plan;
}
size_t lru = 0;
for (size_t i=1; i<nmax; ++i)
if (last_access[i] < last_access[lru])
lru = i;
public:
cache[lru] = plan;
last_access[lru] = ++access_counter;
}
return plan;
#endif
plan_cache(size_t cache_size):
cache_(cache_size),
last_access_(cache_size)
{
}
void resize(size_t new_size)
{
lock_guard<mutex> lock(mut_);
cache_.resize(new_size);
last_access_.resize(new_size);
// TODO: When shrinking, remove lru elements first
}
};
template <typename T> plan_cache<T> & get_plan_cache()
{
static plan_cache<T> cache(POCKETFFT_CACHE_SIZE);
return cache;
}
class arr_info
......@@ -3235,15 +3267,17 @@ template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec>
POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true)
planner<Tplan> *planner=nullptr, const bool allow_inplace=true)
{
shared_ptr<Tplan> plan;
if (!planner)
planner = &get_plan_cache<Tplan>();
for (size_t iax=0; iax<axes.size(); ++iax)
{
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan = get_plan<Tplan>(len);
plan = planner->get_plan(len);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
......@@ -3348,9 +3382,9 @@ struct ExecDcst
template<typename T> POCKETFFT_NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads)
size_t nthreads, planner<pocketfft_r<T>> *planner=nullptr)
{
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
auto plan = planner->get_plan(in.shape(axis));
size_t len=in.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
......@@ -3403,9 +3437,11 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
}
template<typename T> POCKETFFT_NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t nthreads)
size_t nthreads, planner<pocketfft_r<T>> *planner=nullptr)
{
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
if (!planner)
planner = &get_plan_cache<pocketfft_r<T>>();
auto plan = planner->get_plan(out.shape(axis));
size_t len=out.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
......@@ -3481,16 +3517,56 @@ struct ExecR2R
}
};
template <typename T> class preplanned_transform: public planner<T>
{
vector<shared_ptr<T>> plans;
public:
void add_plan(shared_ptr<T> plan)
{
plans.push_back(move(plan));
}
virtual shared_ptr<T> get_plan(size_t length)
{
auto it = std::find(plans.begin(), plans.end(), length);
if (it == plans.end())
throw invalid_argument("Bad pre-planned transform, axis length did not match");
return *it;
}
};
template <typename T> using c2c_plan_t = preplanned_transform<pocketfft_c<T>>;
template <typename T> c2c_plan_t<T> plan_c2c(const shape_t &shape, const shape_t &axes)
{
std::vector<size_t> lengths;
for (auto a : axes)
{
if (a >= shape.size())
throw invalid_argument("Axes exceeds dimensionality of input");
lengths.push_back(shape[a]);
}
std::sort(lengths.begin(), lengths.end());
auto last = std::unique(lengths.begin(), lengths.end());
c2c_plan_t<T> plan;
for (auto l : lengths)
plan.add_plan(make_shared<pocketfft_c<T>>(l));
return move(plan);
}
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward,
const complex<T> *data_in, complex<T> *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, c2c_plan_t<T> *plan=nullptr)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<cmplx<T>> ain(data_in, shape, stride_in);
ndarr<cmplx<T>> aout(data_out, shape, stride_out);
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward},
plan);
}
template<typename T> void dct(const shape_t &shape,
......@@ -3623,7 +3699,7 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
false);
nullptr, false);
}
} // namespace detail
......@@ -3632,6 +3708,7 @@ using detail::FORWARD;
using detail::BACKWARD;
using detail::shape_t;
using detail::stride_t;
using detail::plan_c2c;
using detail::c2c;
using detail::c2r;
using detail::r2c;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment