diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index b74b1c653c748d72c0613506099d5a733ba12e01..ea6538eb194e888ad1cc94e7cfcdfa6893ce0983 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2902,25 +2902,16 @@ template POCKETFFT_NOINLINE void general_c( while (it.remaining()>0) { it.advance(1); - auto tdata = reinterpret_cast *>(storage.data()); - if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(cmplx))) // fully in-place - forward ? plan->forward (&out[it.oofs(0)], fct) - : plan->backward(&out[it.oofs(0)], fct); - else if (it.stride_out()==sizeof(cmplx)) // compute FFT in output location - { - for (size_t i=0; iforward (&out[it.oofs(0)], fct) - : plan->backward(&out[it.oofs(0)], fct); - } - else - { + auto buf = it.stride_out() == sizeof(cmplx) ? + &out[it.oofs(0)] : reinterpret_cast *>(storage.data()); + + if (buf != &tin[it.iofs(0)]) for (size_t i=0; iforward (tdata, fct) : plan->backward(tdata, fct); + buf[i] = tin[it.iofs(i)]; + forward ? plan->forward (buf, fct) : plan->backward(buf, fct); + if (buf != &out[it.oofs(0)]) for (size_t i=0; i POCKETFFT_NOINLINE void general_dcst( while (it.remaining()>0) { it.advance(1); - auto tdata = reinterpret_cast(storage.data()); - if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); - else if (it.stride_out()==sizeof(T)) // compute FFT in output location - { - for (size_t i=0; iexec(&out[it.oofs(0)], fct, ortho, type, cosine); - } - else - { + auto buf = it.stride_out() == sizeof(T) ? &out[it.oofs(0)] + : reinterpret_cast(storage.data()); + + if (buf != &tin[it.iofs(0)]) for (size_t i=0; iexec(tdata, fct, ortho, type, cosine); + buf[i] = tin[it.iofs(i)]; + plan->exec(buf, fct, ortho, type, cosine); + if (buf != &out[it.oofs(0)]) for (size_t i=0; i POCKETFFT_NOINLINE void general_r( while (it.remaining()>0) { it.advance(1); - auto tdata = reinterpret_cast(storage.data()); - if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - { - if ((!r2c) && forward) - for (size_t i=2; iforward (&out[it.oofs(0)], fct) - : plan->backward(&out[it.oofs(0)], fct); - if (r2c && (!forward)) - for (size_t i=2; iforward (&out[it.oofs(0)], fct) - : plan->backward(&out[it.oofs(0)], fct); - if (r2c && (!forward)) - for (size_t i=2; i(storage.data()); + + if (buf != &tin[it.iofs(0)]) for (size_t i=0; iforward (tdata, fct) : plan->backward(tdata, fct); - if (r2c && (!forward)) - for (size_t i=2; iforward(buf, fct) : plan->backward(buf, fct); + if (r2c && (!forward)) + for (size_t i=2; i