...
 
Commits (2)
......@@ -3385,7 +3385,7 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
size_t nthreads, planner<pocketfft_r<T>> *planner=nullptr)
{
if (!planner)
planner = &get_plan_cache<pocketfft_r<T>>();
planner = &get_plan_cache<pocketfft_r<T>>();
auto plan = planner->get_plan(in.shape(axis));
size_t len=in.shape(axis);
threading::thread_map(
......@@ -3523,6 +3523,23 @@ template <typename T> class preplanned_transform: public planner<T>
{
vector<shared_ptr<T>> plans;
public:
preplanned_transform(const shape_t &shape, const shape_t &axes)
{
std::vector<size_t> lengths;
for (auto a : axes)
{
if (a >= shape.size())
throw invalid_argument("Axes exceeds dimensionality of input");
lengths.push_back(shape[a]);
}
std::sort(lengths.begin(), lengths.end());
auto last = std::unique(lengths.begin(), lengths.end());
for (auto l : lengths)
add_plan(make_shared<T>(l));
}
void add_plan(shared_ptr<T> plan)
{
plans.push_back(move(plan));
......@@ -3530,9 +3547,11 @@ template <typename T> class preplanned_transform: public planner<T>
// Returns the stored plan whose length() equals `length`.
// Throws invalid_argument if no stored plan has that length.
virtual shared_ptr<T> get_plan(size_t length)
  {
  // The container holds shared_ptr<T>, not lengths, so the lookup has to
  // compare through the pointer.
  auto it = std::find_if(plans.begin(), plans.end(),
    [&](const shared_ptr<T> &plan){ return plan->length() == length; });
  if (it == plans.end())
    throw invalid_argument(
      "Bad pre-planned transform, axis length did not match");
  return *it;
  }
};
......@@ -3540,24 +3559,6 @@ template <typename T> class preplanned_transform: public planner<T>
template <typename T> using c2c_plan_t = preplanned_transform<pocketfft_c<T>>;
// Creates a pre-planned complex-to-complex transform for the given axes of
// an array with extents `shape`.
// Validation and plan construction are done by the c2c_plan_t constructor,
// which throws invalid_argument if an axis index is out of range for `shape`.
template <typename T> c2c_plan_t<T> plan_c2c(const shape_t &shape, const shape_t &axes)
  {
  return c2c_plan_t<T>(shape, axes);
  }
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward,
const complex<T> *data_in, complex<T> *data_out, T fct,
......@@ -3571,10 +3572,45 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
plan);
}
// Holds the pre-planned transform for one DCT type.
// Exactly one of the three planner pointers is populated by the
// (type, shape, axes) constructor; the default constructor leaves all of
// them null and sets type_ to -1.
template <typename T> struct dct_plan
  {
  unique_ptr<preplanned_transform<T_dct1<T>>> dct1_;    // set for type 1
  unique_ptr<preplanned_transform<T_dcst4<T>>> dct4_;   // set for type 4
  unique_ptr<preplanned_transform<T_dcst23<T>>> dct23_; // set for types 2 and 3
  int type_;

  dct_plan(): type_(-1) {}

  // Builds the planner matching `type` for the given shape and axes.
  // Throws invalid_argument when `type` is outside 1..4.
  dct_plan(int type, const shape_t &shape, const shape_t &axes):
    type_(type)
    {
    if (type==1)
      dct1_.reset(new preplanned_transform<T_dct1<T>>(shape, axes));
    else if ((type==2) || (type==3))
      dct23_.reset(new preplanned_transform<T_dcst23<T>>(shape, axes));
    else if (type==4)
      dct4_.reset(new preplanned_transform<T_dcst4<T>>(shape, axes));
    else
      throw invalid_argument("Type must be in range [1, 4]");
    }
  };
template<typename T> void dct(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1,
dct_plan<T> *plan=nullptr)
{
if (plan && plan->type_ != type)
throw invalid_argument("plan type mismatch");
static dct_plan<T> null_plan;
if (!plan) plan = &null_plan;
if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
......@@ -3582,17 +3618,55 @@ template<typename T> void dct(const shape_t &shape,
ndarr<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, true};
if (type==1)
general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dct1_.get());
else if (type==4)
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dct4_.get());
else
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dct23_.get());
}
// Holds the pre-planned transform for one DST type.
// Exactly one of the three planner pointers is populated by the
// (type, shape, axes) constructor; the default constructor leaves all of
// them null and sets type_ to -1.
template <typename T> struct dst_plan
  {
  unique_ptr<preplanned_transform<T_dst1<T>>> dst1_;    // set for type 1
  unique_ptr<preplanned_transform<T_dcst4<T>>> dst4_;   // set for type 4
  unique_ptr<preplanned_transform<T_dcst23<T>>> dst23_; // set for types 2 and 3
  int type_;

  dst_plan(): type_(-1) {}

  // Builds the planner matching `type` for the given shape and axes.
  // Throws invalid_argument when `type` is outside 1..4.
  dst_plan(int type, const shape_t &shape, const shape_t &axes):
    type_(type)
    {
    if (type==1)
      dst1_.reset(new preplanned_transform<T_dst1<T>>(shape, axes));
    else if ((type==2) || (type==3))
      dst23_.reset(new preplanned_transform<T_dcst23<T>>(shape, axes));
    else if (type==4)
      dst4_.reset(new preplanned_transform<T_dcst4<T>>(shape, axes));
    else
      throw invalid_argument("Type must be in range [1, 4]");
    }
  };
template<typename T> void dst(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1,
dst_plan<T> *plan=nullptr)
{
if (plan && plan->type_ != type)
throw invalid_argument("plan type mismatch");
static dst_plan<T> null_plan;
if (!plan) plan = &null_plan;
if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
......@@ -3600,17 +3674,50 @@ template<typename T> void dst(const shape_t &shape,
ndarr<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, false};
if (type==1)
general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dst1_.get());
else if (type==4)
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dst4_.get());
else
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec,
plan->dst23_.get());
}
// Planner that owns exactly one plan, created up front for a fixed length.
template <typename T> class preplanned_transform_1d : public planner<T>
  {
  shared_ptr<T> plan_;

  public:
    // Creates the single plan for transforms of length `len`.
    preplanned_transform_1d(size_t len)
      : plan_(make_shared<T>(len))
      {}

    // Hands out the stored plan when `len` equals its length();
    // throws invalid_argument otherwise.
    shared_ptr<T> get_plan(size_t len)
      {
      const bool matches = (plan_->length() == len);
      if (matches)
        return plan_;
      throw invalid_argument(
        "bad pre-planned transform, axis length did not match");
      }
  };
template <typename T> struct rc_plan
{
preplanned_transform_1d<pocketfft_r<T>> real_;
c2c_plan_t<T> cmplx_;
rc_plan(const shape_t &shape, const shape_t &axes):
real_(axes.back() < shape.size() ? shape[axes.back()] : 1),
cmplx_(shape, shape_t(axes.begin(), axes.begin() + axes.size() - 1))
{
if (axes.back() >= shape.size())
throw invalid_argument("bad axes argument");
}
};
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const T *data_in, complex<T> *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, rc_plan<T> *plan=nullptr)
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis);
......@@ -3618,31 +3725,32 @@ template<typename T> void r2c(const shape_t &shape_in,
shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1;
ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
general_r2c(ain, aout, axis, forward, fct, nthreads);
general_r2c(ain, aout, axis, forward, fct, nthreads,
plan ? &plan->real_ : nullptr);
}
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const T *data_in, complex<T> *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, rc_plan<T> *plan=nullptr)
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axes);
r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,
fct, nthreads);
fct, nthreads, plan);
if (axes.size()==1) return;
shape_t shape_out(shape_in);
shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;
auto newaxes = shape_t{axes.begin(), --axes.end()};
c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,
T(1), nthreads);
T(1), nthreads, plan ? &plan->cmplx_ : nullptr);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, rc_plan<T> *plan=nullptr)
{
if (util::prod(shape_out)==0) return;
util::sanity_check(shape_out, stride_in, stride_out, false, axis);
......@@ -3650,18 +3758,19 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in[axis] = shape_out[axis]/2 + 1;
cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
ndarr<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, forward, fct, nthreads);
general_c2r(ain, aout, axis, forward, fct, nthreads,
plan ? &plan->real_ : nullptr);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, rc_plan<T> *plan=nullptr)
{
if (util::prod(shape_out)==0) return;
if (axes.size()==1)
return c2r(shape_out, stride_in, stride_out, axes[0], forward,
data_in, data_out, fct, nthreads);
data_in, data_out, fct, nthreads, plan);
util::sanity_check(shape_out, stride_in, stride_out, false, axes);
auto shape_in = shape_out;
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
......@@ -3674,34 +3783,38 @@ template<typename T> void c2r(const shape_t &shape_out,
arr<complex<T>> tmp(nval);
auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),
T(1), nthreads);
T(1), nthreads, plan ? &plan->cmplx_ : nullptr);
c2r(shape_out, stride_inter, stride_out, axes.back(), forward,
tmp.data(), data_out, fct, nthreads);
tmp.data(), data_out, fct, nthreads, plan);
}
template <typename T> using r2r_plan_t = preplanned_transform<pocketfft_r<T>>;
template<typename T> void r2r_fftpack(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct,
size_t nthreads=1)
size_t nthreads=1, r2r_plan_t<T> *plan=nullptr)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
ExecR2R{real2hermitian, forward});
ExecR2R{real2hermitian, forward}, plan);
}
using detail::dct_plan;
template<typename T> void r2r_separable_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
const T *data_in, T *data_out, T fct, size_t nthreads=1,
r2r_plan_t<T> *plan=nullptr)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
nullptr, false);
plan, false);
}
} // namespace detail
......@@ -3710,13 +3823,17 @@ using detail::FORWARD;
using detail::BACKWARD;
using detail::shape_t;
using detail::stride_t;
using detail::plan_c2c;
using detail::c2c_plan_t;
using detail::c2c;
using detail::rc_plan;
using detail::c2r;
using detail::r2c;
using detail::r2r_plan_t;
using detail::r2r_fftpack;
using detail::r2r_separable_hartley;
using detail::dct_plan;
using detail::dct;
using detail::dst_plan;
using detail::dst;
} // namespace pocketfft
......
......@@ -38,6 +38,55 @@ using f64 = double;
using flong = ldbl_t;
auto None = py::none();
// Tag identifying which family of transform a plan was created for.
// Stored in a single byte.
enum class transform_type : unsigned char
{
  c2c = 0,
  rc  = 1,
  r2r = 2,
  dct = 3,
  dst = 4
};
// Tag identifying the floating-point precision a plan was created for.
// Stored in a single byte.
enum class data_type : unsigned char
{
  flt  = 0,
  dbl  = 1,
  ldbl = 2
};
// Type-erased plan handle: two tags plus an untyped, shared-ownership pointer.
// The tags are what plan_caster checks before casting plan_ back to a
// concrete plan type.
struct plan
{
// Transform family tag.
transform_type ttype_;
// Floating-point precision tag.
data_type dtype_;
// The plan object itself, with its concrete type erased.
shared_ptr<void> plan_;
};
// Compile-time map from a floating-point type to its data_type tag,
// available as type_enum<T>::value. The primary template is only declared,
// so using it with any type other than the three below fails to compile.
template <typename T> struct type_enum;
// float -> data_type::flt
template <> struct type_enum<float> :
integral_constant<data_type, data_type::flt> {};
// double -> data_type::dbl
template <> struct type_enum<double> :
integral_constant<data_type, data_type::dbl> {};
// long double -> data_type::ldbl
template <> struct type_enum<long double> :
integral_constant<data_type, data_type::ldbl> {};
// Recovers a concretely typed plan from a type-erased `plan` handle.
// Only declared here; each supported plan type provides a specialisation.
template <typename T> struct plan_caster;

// Specialisation for complex-to-complex plans of precision T.
template <typename T> struct plan_caster<c2c_plan_t<T>>
{
  // Returns the handle's pointer typed as c2c_plan_t<T>.
  // Throws invalid_argument if the handle is not tagged as a c2c plan or
  // was tagged with a different precision than T.
  static shared_ptr<c2c_plan_t<T>> cast (const plan &plan)
  {
    // `plan` is a reference to a plain struct, so members are reached
    // with '.', not '->'.
    if (plan.ttype_ != transform_type::c2c)
      throw invalid_argument("wrong type of plan");
    if (plan.dtype_ != type_enum<T>::value)
      throw invalid_argument("plan is for wrong dtype");
    // static_pointer_cast is the standard way back from shared_ptr<void>
    // and keeps sharing ownership with the handle.
    return std::static_pointer_cast<c2c_plan_t<T>>(plan.plan_);
  }
};
template <typename T> plan_cast(const plan &plan)
{
return plan_caster<T>::cast(plan);
}
shape_t copy_shape(const py::array &arr)
{
shape_t res(size_t(arr.ndim()));
......