diff --git a/pocketfft.cc b/pocketfft.cc index 3721f05cdd798505fd3fda18a1c349222d766ca7..73f7b3a51b7414252c854d546ecac1ddfa465a6c 100644 --- a/pocketfft.cc +++ b/pocketfft.cc @@ -2020,6 +2020,68 @@ class multi_iter_io size_t remaining() const { return rem; } }; +template<size_t N, typename Ti, typename To> class multi_iter + { + private: + vector<diminfo_io> dim; + shape_t pos; + const Ti *p_ii, *p_i[N]; + To *p_oi, *p_o[N]; + size_t len_i, len_o; + int64_t str_i, str_o; + size_t rem; + + void advance_i() + { + for (int i=pos.size()-1; i>=0; --i) + { + ++pos[i]; + p_ii += dim[i].s_i; + p_oi += dim[i].s_o; + if (pos[i] < dim[i].n) + return; + pos[i] = 0; + p_ii -= dim[i].n*dim[i].s_i; + p_oi -= dim[i].n*dim[i].s_o; + } + } + + public: + multi_iter(const Ti *d_i, const shape_t &shape_i, const stride_t &stride_i, + To *d_o, const shape_t &shape_o, const stride_t &stride_o, size_t idim) + : pos(shape_i.size()-1, 0), p_ii(d_i), p_oi(d_o), len_i(shape_i[idim]), + len_o(shape_o[idim]), str_i(stride_i[idim]), + str_o(stride_o[idim]), rem(1) + { + dim.reserve(shape_i.size()-1); + for (size_t i=0; i<shape_i.size(); ++i) + if (i!=idim) + { + dim.push_back({shape_i[i], stride_i[i], stride_o[i]}); + rem *= shape_i[i]; + } + } + void advance(size_t n) + { + if (rem<n) throw runtime_error("underrun"); + for (size_t i=0; i<n; ++i) + { + p_i[i] = p_ii; + p_o[i] = p_oi; + advance_i(); + } + rem -= n; + } + const Ti &in (size_t j, size_t i) const { return p_i[j][i*str_i]; } + To &out (size_t j, size_t i) const { return p_o[j][i*str_o]; } + size_t length_in() const { return len_i; } + size_t length_out() const { return len_o; } + int64_t stride_in() const { return str_i; } + int64_t stride_out() const { return str_o; } + size_t remaining() const { return rem; } + bool inplace() const { return p_ii==p_oi; } + }; + #if (defined(__AVX512F__)) #define HAVE_VECSUPPORT @@ -2104,61 +2166,58 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape, cmplx<T> *data_out, T fct) { auto storage = alloc_tmp<T>(shape, axes, sizeof(cmplx<T>)); - using vtype = typename VTYPE<T>::type; - auto tdatav = (cmplx<vtype> *)storage.data(); - auto tdata = (cmplx<T> *)storage.data(); - constexpr int vlen = VTYPE<T>::vlen; - unique_ptr<pocketfft_c<T>> plan; for (size_t iax=0; iax<axes.size(); ++iax) { - multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out, - axes[iax]); + constexpr int vlen = VTYPE<T>::vlen; + multi_iter<vlen, cmplx<T>, cmplx<T>> it(iax==0? data_in : data_out, shape, + iax==0 ? stride_in : stride_out, data_out, shape, stride_out, axes[iax]); size_t len=it.length_in(); - int64_t s_i=it.stride_in(), s_o=it.stride_out(); if ((!plan) || (len!=plan->length())) plan.reset(new pocketfft_c<T>(len)); if (vlen>1) while (it.remaining()>=vlen) { - multioffset_io<vlen> p(it); + using vtype = typename VTYPE<T>::type; + it.advance(vlen); + auto tdatav = (cmplx<vtype> *)storage.data(); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) { - tdatav[i].r[j] = data_in[p.in(j)+i*s_i].r; - tdatav[i].i[j] = data_in[p.out(j)+i*s_i].i; + tdatav[i].r[j] = it.in(j,i).r; + tdatav[i].i[j] = it.in(j,i).i; } forward ? plan->forward((vtype *)tdatav, fct) : plan->backward((vtype *)tdatav, fct); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)+i*s_o].Set(tdatav[i].r[j],tdatav[i].i[j]); + it.out(j,i).Set(tdatav[i].r[j],tdatav[i].i[j]); } while (it.remaining()>0) { - if ((data_in==data_out) && s_i==1) // fully in-place - forward ? plan->forward((T *)(data_in+it.offset_in()), fct) - : plan->backward((T *)(data_in+it.offset_in()), fct); - else if (s_o==1) // compute FFT in output location + it.advance(1); + auto tdata = (cmplx<T> *)storage.data(); + if (it.inplace() && it.stride_in()==1) // fully in-place + forward ? plan->forward((T *)(&it.in(0,0)), fct) + : plan->backward((T *)(&it.in(0,0)), fct); + else if (it.stride_out()==1) // compute FFT in output location { for (size_t i=0; i<len; ++i) - data_out[it.offset_out()+i] = data_in[it.offset_in()+i*s_i]; - forward ? plan->forward((T *)(data_out+it.offset_out()), fct) - : plan->backward((T *)(data_out+it.offset_out()), fct); + it.out(0,i) = it.in(0,i); + forward ? plan->forward((T *)(&it.out(0,0)), fct) + : plan->backward((T *)(&it.out(0,0)), fct); } else { for (size_t i=0; i<len; ++i) - tdata[i] = data_in[it.offset_in() + i*s_i]; + tdata[i] = it.in(0,i); forward ? plan->forward((T *)tdata, fct) : plan->backward((T *)tdata, fct); for (size_t i=0; i<len; ++i) - data_out[it.offset_out()+i*s_o] = tdata[i]; + it.out(0,i) = tdata[i]; } - it.advance(); } - data_in = data_out; // after the first dimension, take data from output array fct = T(1); // factor has been applied, use 1 for remaining axes } } @@ -2168,60 +2227,57 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap const shape_t &axes, const T *data_in, T *data_out, T fct) { auto storage = alloc_tmp<T>(shape, axes, sizeof(T)); - using vtype = typename VTYPE<T>::type; - auto tdatav = (vtype *)storage.data(); - constexpr int vlen = VTYPE<T>::vlen; - auto tdata = (T *)storage.data(); - unique_ptr<pocketfft_r<T>> plan; for (size_t iax=0; iax<axes.size(); ++iax) { - multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out, - axes[iax]); + constexpr int vlen = VTYPE<T>::vlen; + multi_iter<vlen, T, T> it(iax==0 ? data_in : data_out, shape, + iax==0 ? stride_in : stride_out, data_out, shape, stride_out, axes[iax]); size_t len=it.length_in(); - int64_t s_i=it.stride_in(), s_o=it.stride_out(); if ((!plan) || (len!=plan->length())) plan.reset(new pocketfft_r<T>(len)); if (vlen>1) while (it.remaining()>=vlen) { - multioffset_io<vlen> p(it); + using vtype = typename VTYPE<T>::type; + it.advance(vlen); + auto tdatav = (vtype *)storage.data(); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) - tdatav[i][j] = data_in[p.in(j)+i*s_i]; + tdatav[i][j] = it.in(j,i); plan->forward((vtype *)tdatav, fct); for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)] = tdatav[0][j]; + it.out(j,0) = tdatav[0][j]; size_t i=1, i1=1, i2=len-1; for (i=1; i<len-1; i+=2, ++i1, --i2) for (size_t j=0; j<vlen; ++j) { - data_out[p.out(j)+i1*s_o] = tdatav[i][j]+tdatav[i+1][j]; - data_out[p.out(j)+i2*s_o] = tdatav[i][j]-tdatav[i+1][j]; + it.out(j,i1) = tdatav[i][j]+tdatav[i+1][j]; + it.out(j,i2) = tdatav[i][j]-tdatav[i+1][j]; } if (i<len) for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)+i1*s_o] = tdatav[i][j]; + it.out(j,i1) = tdatav[i][j]; } while (it.remaining()>0) { + it.advance(1); + auto tdata = (T *)storage.data(); for (size_t i=0; i<len; ++i) - tdata[i] = data_in[it.offset_in()+i*s_i]; + tdata[i] = it.in(0,i); plan->forward((T *)tdata, fct); // Hartley order - data_out[it.offset_out()] = tdata[0]; + it.out(0,0) = tdata[0]; size_t i=1, i1=1, i2=len-1; for (i=1; i<len-1; i+=2, ++i1, --i2) { - data_out[it.offset_out()+i1*s_o] = tdata[i]+tdata[i+1]; - data_out[it.offset_out()+i2*s_o] = tdata[i]-tdata[i+1]; + it.out(0,i1) = tdata[i]+tdata[i+1]; + it.out(0,i2) = tdata[i]-tdata[i+1]; } if (i<len) - data_out[it.offset_out()+i1*s_o] = tdata[i]; - it.advance(); + it.out(0,i1) = tdata[i]; } - data_in = data_out; // after the first dimension, take data from output array fct = T(1); // factor has been applied, use 1 for remaining axes } } @@ -2231,47 +2287,45 @@ template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape, const T *data_in, cmplx<T> *data_out, T fct) { auto storage = alloc_tmp<T>(shape, shape[axis], sizeof(T)); - using vtype = typename VTYPE<T>::type; - auto tdatav = (vtype *)storage.data(); - constexpr int vlen = VTYPE<T>::vlen; - auto tdata = (T *)storage.data(); pocketfft_r<T> plan(shape[axis]); - multi_iter_io it(shape, stride_in, shape, stride_out, axis); + constexpr int vlen = VTYPE<T>::vlen; + multi_iter<vlen, T, cmplx<T>> + it(data_in, shape, stride_in, data_out, shape, stride_out, axis); size_t len=shape[axis]; - int64_t s_i=it.stride_in(), s_o=it.stride_out(); if (vlen>1) while (it.remaining()>=vlen) { - multioffset_io<vlen> p(it); + using vtype = typename VTYPE<T>::type; + it.advance(vlen); + auto tdatav = (vtype *)storage.data(); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) - tdatav[i][j] = data_in[p.in(j)+i*s_i]; + tdatav[i][j] = it.in(j,i); plan.forward((vtype *)tdatav, fct); for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)].Set(tdatav[0][j]); + it.out(j,0).Set(tdatav[0][j]); size_t i=1, ii=1; for (; i<len-1; i+=2, ++ii) for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)+ii*s_o].Set(tdatav[i][j], tdatav[i+1][j]); + it.out(j,ii).Set(tdatav[i][j], tdatav[i+1][j]); if (i<len) for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)+ii*s_o].Set(tdatav[i][j]); + it.out(j,ii).Set(tdatav[i][j]); } while (it.remaining()>0) { - const T *d_i = data_in+it.offset_in(); - cmplx<T> *d_o = data_out+it.offset_out(); + it.advance(1); + auto tdata = (T *)storage.data(); for (size_t i=0; i<len; ++i) - tdata[i] = d_i[i*s_i]; + tdata[i] = it.in(0,i); plan.forward(tdata, fct); - d_o[0].Set(tdata[0]); + it.out(0,0).Set(tdata[0]); size_t i=1, ii=1; for (; i<len-1; i+=2, ++ii) - d_o[ii*s_o].Set(tdata[i], tdata[i+1]); + it.out(0,ii).Set(tdata[i], tdata[i+1]); if (i<len) - d_o[ii*s_o].Set(tdata[i]); - it.advance(); + it.out(0,ii).Set(tdata[i]); } } template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_out, @@ -2279,53 +2333,51 @@ template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_ou const cmplx<T> *data_in, T *data_out, T fct) { auto storage = alloc_tmp<T>(shape_out, shape_out[axis], sizeof(T)); - using vtype = typename VTYPE<T>::type; - auto tdatav = (vtype *)storage.data(); - constexpr int vlen = VTYPE<T>::vlen; - auto tdata = (T *)storage.data(); - pocketfft_r<T> plan(shape_out[axis]); - multi_iter_io it(shape_out, stride_in, shape_out, stride_out, axis); + + constexpr int vlen = VTYPE<T>::vlen; + multi_iter<vlen, cmplx<T>, T> + it(data_in, shape_out, stride_in, data_out, shape_out, stride_out, axis); size_t len=shape_out[axis]; - int64_t s_i=it.stride_in(), s_o=it.stride_out(); if (vlen>1) while (it.remaining()>=vlen) { - multioffset_io<vlen> p(it); + using vtype = typename VTYPE<T>::type; + it.advance(vlen); + auto tdatav = (vtype *)storage.data(); for (size_t j=0; j<vlen; ++j) - tdatav[0][j]=data_in[p.in(j)].r; + tdatav[0][j]=it.in(j,0).r; size_t i=1, ii=1; for (; i<len-1; i+=2, ++ii) for (size_t j=0; j<vlen; ++j) { - tdatav[i][j] = data_in[p.in(j)+ii*s_i].r; - tdatav[i+1][j] = data_in[p.in(j)+ii*s_i].i; + tdatav[i][j] = it.in(j,ii).r; + tdatav[i+1][j] = it.in(j,ii).i; } if (i<len) for (size_t j=0; j<vlen; ++j) - tdatav[i][j] = data_in[p.in(j)+ii*s_i].r; + tdatav[i][j] = it.in(j,ii).r; plan.backward(tdatav, fct); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) - data_out[p.out(j)+i*s_o] = tdatav[i][j]; + it.out(j,i) = tdatav[i][j]; } while (it.remaining()>0) { - const cmplx<T> *d_i = data_in+it.offset_in(); - T *d_o = data_out+it.offset_out(); - tdata[0]=d_i[0].r; + it.advance(1); + auto tdata = (T *)storage.data(); + tdata[0]=it.in(0,0).r; size_t i=1, ii=1; for (; i<len-1; i+=2, ++ii) { - tdata[i] = d_i[ii*s_i].r; - tdata[i+1] = d_i[ii*s_i].i; + tdata[i] = it.in(0,ii).r; + tdata[i+1] = it.in(0,ii).i; } if (i<len) - tdata[i] = d_i[ii*s_i].r; + tdata[i] = it.in(0,ii).r; plan.backward(tdata, fct); for (size_t i=0; i<len; ++i) - d_o[i*s_o] = tdata[i]; - it.advance(); + it.out(0,i) = tdata[i]; } }