From 46b8e3af62e4963bb01926a2db237a9707d307a7 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 24 Apr 2019 00:27:17 +0200 Subject: [PATCH] cleanup --- pocketfft.cc | 101 +++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/pocketfft.cc b/pocketfft.cc index 6f54dcc..80505b5 100644 --- a/pocketfft.cc +++ b/pocketfft.cc @@ -1581,7 +1581,7 @@ template<typename T> NOINLINE void radbg(size_t ido, size_t ip, size_t l1, #undef MULPM #undef WA -template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct) +template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct) { if (p1!=c) { @@ -1963,6 +1963,9 @@ template<typename T0> class pocketfft_r // multi-D infrastructure // +using shape_t = vector<size_t>; +using stride_t = vector<int64_t>; + struct diminfo { size_t n; int64_t s; }; class multiarr @@ -1971,7 +1974,7 @@ class multiarr vector<diminfo> dim; public: - multiarr (const vector<size_t> &n, const vector<int64_t> &s) + multiarr (const shape_t &n, const stride_t &s) { dim.reserve(n.size()); for (size_t i=0; i<n.size(); ++i) @@ -1986,7 +1989,7 @@ class multi_iter { public: vector<diminfo> dim; - vector<size_t> pos; + shape_t pos; size_t ofs_, len; int64_t str; int64_t rem; @@ -2054,7 +2057,7 @@ template<> struct VTYPE<float> { using type = __m128; }; #endif -template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape, +template<typename T> arr<char> alloc_tmp(const shape_t &shape, size_t tmpsize, size_t elemsize) { #ifdef HAVE_VECSUPPORT @@ -2064,8 +2067,8 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape, #endif return arr<char>(tmpsize*elemsize); } -template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape, - const vector<size_t> &axes, size_t elemsize) +template<typename T> arr<char> alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) { size_t fullsize=1; size_t ndim = shape.size(); @@ -2078,9 +2081,9 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape, return alloc_tmp<T>(shape, tmpsize, elemsize); } -template<typename T> void pocketfft_general_c(const vector<size_t> &shape, - const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, - const vector<size_t> &axes, bool forward, const cmplx<T> *data_in, +template<typename T> void pocketfft_general_c(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, + const shape_t &axes, bool forward, const cmplx<T> *data_in, cmplx<T> *data_out, T fct) { auto storage = alloc_tmp<T>(shape, axes, sizeof(cmplx<T>)); @@ -2144,9 +2147,9 @@ template<typename T> void pocketfft_general_c(const vector<size_t> &shape, } } -template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape, - const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, - const vector<size_t> &axes, const T *data_in, T *data_out, T fct) +template<typename T> void pocketfft_general_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, + const shape_t &axes, const T *data_in, T *data_out, T fct) { auto storage = alloc_tmp<T>(shape, axes, sizeof(T)); #ifdef HAVE_VECSUPPORT @@ -2218,8 +2221,8 @@ template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape, } } -template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape, - const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis, +template<typename T> void pocketfft_general_r2c(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, const T *data_in, cmplx<T> *data_out, T fct) { auto storage = alloc_tmp<T>(shape, shape[axis], sizeof(T)); @@ -2299,8 +2302,8 @@ template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape, it_out.advance(); } } -template<typename T> void pocketfft_general_c2r(const vector<size_t> &shape_out, - const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis, +template<typename T> void pocketfft_general_c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, const cmplx<T> *data_in, T *data_out, T fct) { auto storage = alloc_tmp<T>(shape_out, shape_out[axis], sizeof(T)); @@ -2384,17 +2387,26 @@ auto c128 = py::dtype("complex128"); auto f32 = py::dtype("float32"); auto f64 = py::dtype("float64"); -vector<size_t> copy_shape(const py::array &arr) +bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2) + { + if (arr.dtype().is(t1)) + return true; + if (arr.dtype().is(t2)) + return false; + throw runtime_error("unsupported data type"); + } + +shape_t copy_shape(const py::array &arr) { - vector<size_t> res(arr.ndim()); + shape_t res(arr.ndim()); for (size_t i=0; i<res.size(); ++i) res[i] = arr.shape(i); return res; } -vector<int64_t> copy_strides(const py::array &arr) +stride_t copy_strides(const py::array &arr) { - vector<int64_t> res(arr.ndim()); + stride_t res(arr.ndim()); for (size_t i=0; i<res.size(); ++i) { res[i] = arr.strides(i)/arr.itemsize(); @@ -2404,16 +2416,16 @@ vector<int64_t> copy_strides(const py::array &arr) return res; } -vector<size_t> makeaxes(const py::array &in, py::object axes) +shape_t makeaxes(const py::array &in, py::object axes) { if (axes.is(py::none())) { - vector<size_t> res(in.ndim()); + shape_t res(in.ndim()); for (size_t i=0; i<res.size(); ++i) res[i]=i; return res; } - auto tmp=axes.cast<vector<size_t>>(); + auto tmp=axes.cast<shape_t>(); if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0)) throw runtime_error("bad axes argument"); for (size_t i=0; i<tmp.size(); ++i) @@ -2423,7 +2435,7 @@ vector<size_t> makeaxes(const py::array &in, py::object axes) } template<typename T> py::array xfftn_internal(const py::array &in, - const vector<size_t> &axes, double fct, bool fwd) + const shape_t &axes, double fct, bool fwd) { auto dims(copy_shape(in)); py::array res = py::array_t<complex<T>>(dims); @@ -2435,11 +2447,9 @@ template<typename T> py::array xfftn_internal(const py::array &in, py::array xfftn(const py::array &a, py::object axes, double fct, bool fwd) { - if (a.dtype().is(c128)) - return xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd); - else if (a.dtype().is(c64)) - return xfftn_internal<float>(a, makeaxes(a, axes), fct, fwd); - throw runtime_error("unsupported data type"); + return tcheck(a, c128, c64) ? + xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd) : + xfftn_internal<float> (a, makeaxes(a, axes), fct, fwd); } py::array fftn(const py::array &a, py::object axes, double fct) { return xfftn(a, axes, fct, true); } @@ -2457,7 +2467,7 @@ template<typename T> py::array rfftn_internal(const py::array &in, pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(), (cmplx<T> *)res.mutable_data(), fct); if (axes.size()==1) return res; - vector<size_t> axes2(axes.size()-1); + shape_t axes2(axes.size()-1); for (size_t i=0; i<axes2.size(); ++i) axes2[i] = axes[i]; pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true, @@ -2466,11 +2476,8 @@ template<typename T> py::array rfftn_internal(const py::array &in, } py::array rfftn(const py::array &in, py::object axes_, double fct) { - if (in.dtype().is(f64)) - return rfftn_internal<double>(in, axes_, fct); - else if (in.dtype().is(f32)) - return rfftn_internal<float>(in, axes_, fct); - else throw runtime_error("unsupported data type"); + return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct) + : rfftn_internal<float> (in, axes_, fct); } template<typename T> py::array irfftn_internal(const py::array &in, py::object axes_, size_t lastsize, T fct) @@ -2479,7 +2486,7 @@ template<typename T> py::array irfftn_internal(const py::array &in, py::array inter; if (axes.size()>1) // need to do complex transforms { - vector<size_t> axes2(axes.size()-1); + shape_t axes2(axes.size()-1); for (size_t i=0; i<axes2.size(); ++i) axes2[i] = axes[i]; inter = xfftn_internal<T>(in, axes2, 1., false); @@ -2502,11 +2509,9 @@ template<typename T> py::array irfftn_internal(const py::array &in, py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, double fct) { - if (in.dtype().is(c128)) - return irfftn_internal<double>(in, axes_, lastsize, fct); - else if (in.dtype().is(c64)) - return irfftn_internal<float>(in, axes_, lastsize, fct); - throw runtime_error("unsupported data type"); + return tcheck(in, c128, c64) ? + irfftn_internal<double>(in, axes_, lastsize, fct) : + irfftn_internal<float> (in, axes_, lastsize, fct); } template<typename T> py::array hartley_internal(const py::array &in, @@ -2522,11 +2527,8 @@ template<typename T> py::array hartley_internal(const py::array &in, } py::array hartley(const py::array &in, py::object axes_, double fct) { - if (in.dtype().is(f64)) - return hartley_internal<double>(in, axes_, fct); - else if (in.dtype().is(f32)) - return hartley_internal<float>(in, axes_, fct); - throw runtime_error("unsupported data type"); + return tcheck(in, f64, f32) ? hartley_internal<double>(in, axes_, fct) + : hartley_internal<float> (in, axes_, fct); } template<typename T>py::array complex2hartley(const py::array &in, @@ -2579,11 +2581,8 @@ template<typename T>py::array complex2hartley(const py::array &in, py::array mycomplex2hartley(const py::array &in, const py::array &tmp, py::object axes_) { - if (in.dtype().is(f64)) - return complex2hartley<double>(in, tmp, axes_); - else if (in.dtype().is(f32)) - return complex2hartley<float>(in, tmp, axes_); - else throw runtime_error("unsupported data type"); + return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_) + : complex2hartley<float> (in, tmp, axes_); } py::array hartley2(const py::array &in, py::object axes_, double fct) { return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_); } -- GitLab