From 007a96f6fc15421c82b01bff7c05bbd9d2a2a68b Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 29 Apr 2019 11:53:00 +0200 Subject: [PATCH] simplify --- pocketfft.cc | 79 ++++++---------------------------------------------- 1 file changed, 8 insertions(+), 71 deletions(-) diff --git a/pocketfft.cc b/pocketfft.cc index 73f7b3a..8488848 100644 --- a/pocketfft.cc +++ b/pocketfft.cc @@ -1968,63 +1968,14 @@ template<typename T0> class pocketfft_r using shape_t = vector<size_t>; using stride_t = vector<int64_t>; -struct diminfo_io - { size_t n; int64_t s_i, s_o; }; - -class multi_iter_io +template<size_t N, typename Ti, typename To> class multi_iter { public: + struct diminfo_io + { size_t n; int64_t s_i, s_o; }; vector<diminfo_io> dim; shape_t pos; - int64_t ofs_i, ofs_o; - size_t len_i, len_o; - int64_t str_i, str_o; - size_t rem; - - public: - multi_iter_io(const shape_t &shape_i, const stride_t &stride_i, - const shape_t &shape_o, const stride_t &stride_o, size_t idim) - : pos(shape_i.size()-1, 0), ofs_i(0), ofs_o(0), len_i(shape_i[idim]), - len_o(shape_o[idim]), str_i(stride_i[idim]), - str_o(stride_o[idim]), rem(1) - { - dim.reserve(shape_i.size()-1); - for (size_t i=0; i<shape_i.size(); ++i) - if (i!=idim) - { - dim.push_back({shape_i[i], stride_i[i], stride_o[i]}); - rem *= shape_i[i]; - } - } - void advance() - { - if (--rem==0) return; - for (int i=pos.size()-1; i>=0; --i) - { - ++pos[i]; - ofs_i += dim[i].s_i; - ofs_o += dim[i].s_o; - if (pos[i] < dim[i].n) - return; - pos[i] = 0; - ofs_i -= dim[i].n*dim[i].s_i; - ofs_o -= dim[i].n*dim[i].s_o; - } - } - int64_t offset_in() const { return ofs_i; } - int64_t offset_out() const { return ofs_o; } - size_t length_in() const { return len_i; } - size_t length_out() const { return len_o; } - int64_t stride_in() const { return str_i; } - int64_t stride_out() const { return str_o; } - size_t remaining() const { return rem; } - }; - -template<size_t N, typename Ti, typename To> class multi_iter - { private: - vector<diminfo_io> dim; - shape_t pos; const Ti *p_ii, *p_i[N]; To *p_oi, *p_o[N]; size_t len_i, len_o; @@ -2148,18 +2099,6 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape, return arr<char>(tmpsize*elemsize); } -template<size_t vlen> struct multioffset_io - { - int64_t ofs_i[vlen], ofs_o[vlen]; - multioffset_io(multi_iter_io &it) - { - for (size_t i=0; i<vlen; ++i) - { ofs_i[i] = it.offset_in(); ofs_o[i] = it.offset_out(); it.advance(); } - } - int64_t in(size_t i) const { return ofs_i[i]; } - int64_t out(size_t i) const { return ofs_o[i]; } - }; - template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const cmplx<T> *data_in, @@ -2536,17 +2475,15 @@ template<typename T>py::array complex2hartley(const py::array &in, auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out)); auto axes = makeaxes(in, axes_); size_t axis = axes.back(); - multi_iter_io it(dims_tmp, stride_tmp, dims_out, stride_out, axis); const cmplx<T> *tdata = (const cmplx<T> *)tmp.data(); T *odata = (T *)out.mutable_data(); + multi_iter<1,cmplx<T>,T> it(tdata, dims_tmp, stride_tmp, odata, dims_out, stride_out, axis); vector<bool> swp(ndim-1,false); for (auto i: axes) if (i!=axis) swp[i<axis ? i : i-1] = true; while(it.remaining()>0) { - auto tofs = it.offset_in(); - auto ofs = it.offset_out(); int64_t rofs = 0; for (size_t i=0; i<it.pos.size(); ++i) { @@ -2558,15 +2495,15 @@ template<typename T>py::array complex2hartley(const py::array &in, rofs += x*it.dim[i].s_o; } } + it.advance(1); for (size_t i=0; i<it.length_in(); ++i) { - auto re = tdata[tofs + i*it.stride_in()].r; - auto im = tdata[tofs + i*it.stride_in()].i; + auto re = it.in(0,i).r; + auto im = it.in(0,i).i; auto rev_i = (i==0) ? 0 : it.length_out()-i; - odata[ofs + i*it.stride_out()] = re+im; + it.out(0,i) = re+im; odata[rofs + rev_i*it.stride_out()] = re-im; } - it.advance(); } return out; } -- GitLab