diff --git a/pypocketfft.cc b/pypocketfft.cc index 8b43ac40f586b486ace25d42f5c24122118606d7..247c7dc9c1b243e259c9af1d42aeeb52ac53c092 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -38,6 +38,55 @@ using f64 = double; using flong = ldbl_t; auto None = py::none(); +enum class transform_type : unsigned char + { + c2c, + rc, + r2r, + dct, + dst + }; + +enum class data_type : unsigned char + { + flt, dbl, ldbl + }; + +struct plan + { + transform_type ttype_; + data_type dtype_; + shared_ptr<void> plan_; + }; + +template <typename T> struct type_enum; +template <> struct type_enum<float> : + integral_constant<data_type, data_type::flt> {}; +template <> struct type_enum<double> : + integral_constant<data_type, data_type::dbl> {}; +template <> struct type_enum<long double> : + integral_constant<data_type, data_type::ldbl> {}; + +template <typename T> struct plan_caster; +template <typename T> struct plan_caster<c2c_plan_t<T>> + { + static shared_ptr<c2c_plan_t<T>> cast (const plan &plan) + { + if (plan->ttype_ != transform_type::c2c) + throw invalid_argument("wrong type of plan"); + + if (plan->dtype_ != type_enum<T>::value) + throw invalid_argument("plan is for wrong dtype"); + + return reinterpret_pointer_cast<c2c_plan_t<T>>(plan->plan_); + } + }; + +template <typename T> plan_cast(const plan &plan) + { + return plan_caster<T>::cast(plan); + } + shape_t copy_shape(const py::array &arr) { shape_t res(size_t(arr.ndim()));