diff --git a/pypocketfft.cc b/pypocketfft.cc index e96a0da8fb6171ab93b8b19c72fa73cf9e30d44f..2b4b873af0ee88a35959bf63d1d5aab9ee04e828 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -31,17 +31,10 @@ namespace py = pybind11; auto c64 = py::dtype("complex64"); auto c128 = py::dtype("complex128"); +//auto c256 = py::dtype("complex256"); auto f32 = py::dtype("float32"); auto f64 = py::dtype("float64"); - -bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2) - { - if (arr.dtype().is(t1)) - return true; - if (arr.dtype().is(t2)) - return false; - throw runtime_error("unsupported data type"); - } +//auto f128 = py::dtype("float128"); shape_t copy_shape(const py::array &arr) { @@ -77,6 +70,13 @@ shape_t makeaxes(const py::array &in, py::object axes) return tmp; } +#define DISPATCH(arr, T1, T2, T3, func, args) \ + auto dtype = arr.dtype(); \ + if (dtype.is(T1)) return func<double> args; \ + if (dtype.is(T2)) return func<float> args; \ +/* if (dtype.is(T3)) return func<long double> args; */ \ + throw runtime_error("unsupported data type"); + template<typename T> py::array xfftn_internal(const py::array &in, const shape_t &axes, double fct, bool inplace, bool fwd) { @@ -91,12 +91,13 @@ template<typename T> py::array xfftn_internal(const py::array &in, py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace, bool fwd) { - return tcheck(a, c128, c64) ? - xfftn_internal<double>(a, makeaxes(a, axes), fct, inplace, fwd) : - xfftn_internal<float> (a, makeaxes(a, axes), fct, inplace, fwd); + DISPATCH(a, c128, c64, c256, xfftn_internal, (a, makeaxes(a, axes), fct, + inplace, fwd)) } + py::array fftn(const py::array &a, py::object axes, double fct, bool inplace) { return xfftn(a, axes, fct, inplace, true); } + py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace) { return xfftn(a, axes, fct, inplace, false); } @@ -112,11 +113,12 @@ template<typename T> py::array rfftn_internal(const py::array &in, reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct)); return res; } + py::array rfftn(const py::array &in, py::object axes_, double fct) { - return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct) - : rfftn_internal<float> (in, axes_, fct); + DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct)) } + template<typename T> py::array xrfft_scipy(const py::array &in, size_t axis, double fct, bool inplace, bool fwd) { @@ -127,18 +129,16 @@ template<typename T> py::array xrfft_scipy(const py::array &in, reinterpret_cast<T *>(res.mutable_data()), T(fct)); return res; } + py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace) { - return tcheck(in, f64, f32) ? - xrfft_scipy<double>(in, axis, fct, inplace, true) : - xrfft_scipy<float> (in, axis, fct, inplace, true); + DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true)) } + py::array irfft_scipy(const py::array &in, size_t axis, double fct, bool inplace) { - return tcheck(in, f64, f32) ? - xrfft_scipy<double>(in, axis, fct, inplace, false) : - xrfft_scipy<float> (in, axis, fct, inplace, false); + DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false)) } template<typename T> py::array irfftn_internal(const py::array &in, py::object axes_, size_t lastsize, T fct) @@ -156,12 +156,11 @@ template<typename T> py::array irfftn_internal(const py::array &in, reinterpret_cast<T *>(res.mutable_data()), T(fct)); return res; } + py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, double fct) { - return tcheck(in, c128, c64) ? - irfftn_internal<double>(in, axes_, lastsize, fct) : - irfftn_internal<float> (in, axes_, lastsize, fct); + DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct)) } template<typename T> py::array hartley_internal(const py::array &in, @@ -174,12 +173,11 @@ template<typename T> py::array hartley_internal(const py::array &in, reinterpret_cast<T *>(res.mutable_data()), T(fct)); return res; } + py::array hartley(const py::array &in, py::object axes_, double fct, bool inplace) { - return tcheck(in, f64, f32) ? - hartley_internal<double>(in, axes_, fct, inplace) : - hartley_internal<float> (in, axes_, fct, inplace); + DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace)) } template<typename T>py::array complex2hartley(const py::array &in, @@ -224,12 +222,13 @@ template<typename T>py::array complex2hartley(const py::array &in, } return out; } + py::array mycomplex2hartley(const py::array &in, const py::array &tmp, py::object axes_, bool inplace) { - return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_, inplace) - : complex2hartley<float> (in, tmp, axes_, inplace); + DISPATCH(in, f64, f32, f128, complex2hartley, (in, tmp, axes_, inplace)) } + py::array hartley2(const py::array &in, py::object axes_, double fct, bool inplace) { return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); }