/* This file is part of pocketfft. Copyright (C) 2010-2020 Max-Planck-Society Copyright (C) 2019 Peter Bell For the odd-sized DCT-IV transforms: Copyright (C) 2003, 2007-14 Matteo Frigo Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology Authors: Martin Reinecke, Peter Bell All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ #ifndef MRUTIL_FFT_H #define MRUTIL_FFT_H #include "mr_util/math/fft1d.h" #ifndef POCKETFFT_CACHE_SIZE #define POCKETFFT_CACHE_SIZE 16 #endif #include #include #include #include #include #include #include #if POCKETFFT_CACHE_SIZE!=0 #include #endif #include "mr_util/infra/threading.h" #include "mr_util/infra/simd.h" #include "mr_util/infra/mav.h" #ifndef MRUTIL_NO_THREADING #include #endif namespace mr { namespace detail_fft { /* Named values for the boolean forward arguments taken by the transforms in this file. */ constexpr bool FORWARD = true, BACKWARD = false; struct util // hack to avoid duplicate symbols { /* Throws std::invalid_argument when axes is empty, holds an index >= ndim, or names the same axis more than once. */ static void sanity_check_axes(size_t ndim, const shape_t &axes) { shape_t tmp(ndim,0); if (axes.empty()) throw std::invalid_argument("no axes specified"); for (auto ax : axes) { if (ax>=ndim) throw std::invalid_argument("bad axis number"); if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); } } /* Checks axes against a1.ndim(), asserts that a1 and a2 are conformable and, when inplace is set, that both have equal strides. */ static MRUTIL_NOINLINE void sanity_check_onetype(const fmav_info &a1, const fmav_info &a2, bool inplace, const shape_t &axes) { sanity_check_axes(a1.ndim(), axes); MR_assert(a1.conformable(a2), "array sizes are not conformable"); if (inplace) MR_assert(a1.stride()==a2.stride(), "stride mismatch"); } /* Checks axes against ac.ndim() and asserts that ac and ar have the same number of dimensions. NOTE(review): the remaining text of this function (and of whatever followed it up to T_dct1) is cut off in this copy of the file - restore from the upstream header before relying on it. */ static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac, const fmav_info &ar, const shape_t &axes) { sanity_check_axes(ac.ndim(), axes); MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch"); for (size_t i=0; i=ac.ndim()) throw std::invalid_argument("bad axis number"); MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch"); for (size_t i=0; i /* Plan wrapper around a real FFT plan of length 2*(length-1); presumably the DCT-I plan, judging by the name only - TODO confirm. */ class T_dct1 { private: pocketfft_r fftplan; public: MRUTIL_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} template MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho, int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; /* with ortho set, the first and last of the n entries are scaled by sqrt(2) before the transform */ if (ortho) { c[0]*=sqrt2; c[n-1]*=sqrt2; } aligned_array tmp(N); tmp[0] = c[0]; for (size_t i=1; i class T_dst1 { private: pocketfft_r fftplan; public: MRUTIL_NOINLINE 
T_dst1(size_t length) : fftplan(2*(length+1)) {} /* the wrapped real FFT plan has length 2*(length+1); exec works on a temporary of that length whose entries 0 and n+1 are set to zero */ template MRUTIL_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool /*cosine*/) const { size_t N=fftplan.length(), n=N/2-1; aligned_array tmp(N); tmp[0] = tmp[n+1] = c[0]*0; for (size_t i=0; i /* Plan holding a real FFT plan of the given length plus a twiddle table of the same length, filled from a UnityRoots object of order 4*length. exec branches on type==2 and on the cosine flag; presumably this serves transform types 2 and 3 - TODO confirm, part of the body is cut off in this copy. */ class T_dcst23 { private: pocketfft_r fftplan; std::vector twiddle; public: MRUTIL_NOINLINE T_dcst23(size_t length) : fftplan(length), twiddle(length) { UnityRoots> tw(4*length); for (size_t i=0; i MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho, int type, bool cosine) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); size_t NS2 = (N+1)/2; if (type==2) { if (!cosine) for (size_t k=1; k /* Plan that owns exactly one sub-plan: a complex FFT of length N/2 when N is even (rfft stays null, C2 has N/2 entries), or a real FFT of length N when N is odd (fft stays null, C2 is empty). */ class T_dcst4 { private: size_t N; std::unique_ptr> fft; std::unique_ptr> rfft; aligned_array> C2; public: MRUTIL_NOINLINE T_dcst4(size_t length) : N(length), fft((N&1) ? nullptr : new pocketfft_c(N/2)), rfft((N&1)? new pocketfft_r(N) : nullptr), C2((N&1) ? 0 : N/2) { if ((N&1)==0) { UnityRoots> tw(16*N); for (size_t i=0; i MRUTIL_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/, int /*type*/, bool cosine) const { size_t n2 = N/2; if (!cosine) for (size_t k=0, kc=N-1; k y(N); { size_t i=0, m=n2; for (; mexec(y.data(), fct, true); { auto SGN = [](size_t i) { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); return (i&2) ? 
-sqrt2 : sqrt2; }; c[n2] = y[0]*SGN(n2+1); size_t i=0, i1=1, k=1; for (; k> y(n2); for(size_t i=0; iexec(y.data(), fct, true); for(size_t i=0, ic=n2-1; i /* Returns a shared plan object for the requested length. With POCKETFFT_CACHE_SIZE==0 a fresh plan is created on every call. Otherwise function-local static arrays with POCKETFFT_CACHE_SIZE slots are searched first; last_access and access_counter record how recently each slot was used, and a slot index named lru is picked for a newly built plan (the selection loop is cut off in this copy of the file). */ std::shared_ptr get_plan(size_t length) { #if POCKETFFT_CACHE_SIZE==0 return std::make_shared(length); #else constexpr size_t nmax=POCKETFFT_CACHE_SIZE; static std::array, nmax> cache; static std::array last_access{{0}}; static size_t access_counter = 0; auto find_in_cache = [&]() -> std::shared_ptr { for (size_t i=0; ilength()==length)) { // no need to update if this is already the most recent entry if (last_access[i]!=access_counter) { last_access[i] = ++access_counter; // Guard against overflow if (access_counter == 0) last_access.fill(0); } return cache[i]; } return nullptr; }; #ifdef MRUTIL_NO_THREADING auto p = find_in_cache(); if (p) return p; auto plan = std::make_shared(length); size_t lru = 0; for (size_t i=1; i lock(mut); auto p = find_in_cache(); if (p) return p; } /* the plan is built between the two lock-guarded sections; the cache is searched once more under the second lock before a slot is replaced */ auto plan = std::make_shared(length); { std::lock_guard lock(mut); auto p = find_in_cache(); if (p) return p; size_t lru = 0; for (size_t i=1; i /* Iterator over every index combination of the axes other than idim, keeping a running input offset (p_ii) and output offset (p_oi). rem starts as iarr.size()/iarr.shape(idim), so iarr.shape(idim) must not be zero. nshares and myshare select one part of the work; nshares==0 or myshare>=nshares throws std::runtime_error. */ class multi_iter { private: shape_t pos; fmav_info iarr, oarr; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; size_t idim, rem; void advance_i() { for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); if (i==idim) continue; p_ii += iarr.stride(i); p_oi += oarr.stride(i); if (++pos[i] < iarr.shape(i)) return; pos[i] = 0; p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); } } public: multi_iter(const fmav_info &iarr_, const fmav_info &oarr_, size_t idim_, size_t nshares, size_t myshare) : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), idim(idim_), rem(iarr.size()/iarr.shape(idim)) { if (nshares==1) return; if (nshares==0) throw std::runtime_error("can't run with zero threads"); if (myshare>=nshares) throw std::runtime_error("impossible share requested"); size_t 
nbase = rem/nshares; size_t additional = rem%nshares; /* NOTE(review): the rest of the multi_iter constructor and the head of the following iterator class are cut off in this copy of the file */ size_t lo = myshare*nbase + ((myshare=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (++pos[i] < arr.shape(i)) return; pos[i] = 0; p -= ptrdiff_t(arr.shape(i))*arr.stride(i); } } ptrdiff_t ofs() const { return p; } size_t remaining() const { return rem; } }; /* Iterator over arr in which the last entry of axes is walked over only shape/2+1 positions. ofs() is the running offset p; rev_ofs() is a second offset rp that moves forward along axes not listed in axes and backward (after a one-time jump of shape*stride) along the listed ones. remaining() starts as the product of the reduced shape. */ class rev_iter { private: shape_t pos; fmav_info arr; std::vector rev_axis; std::vector rev_jump; size_t last_axis, last_size; shape_t shp; ptrdiff_t p, rp; size_t rem; public: rev_iter(const fmav_info &arr_, const shape_t &axes) : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), rev_jump(arr_.ndim(), 1), p(0), rp(0) { for (auto ax: axes) rev_axis[ax]=1; last_axis = axes.back(); last_size = arr.shape(last_axis)/2 + 1; shp = arr.shape(); shp[last_axis] = last_size; rem=1; for (auto i: shp) rem *= i; } void advance() { --rem; for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (!rev_axis[i]) rp += arr.stride(i); else { rp -= arr.stride(i); if (rev_jump[i]) { rp += ptrdiff_t(arr.shape(i))*arr.stride(i); rev_jump[i] = 0; } } if (++pos[i] < shp[i]) return; pos[i] = 0; p -= ptrdiff_t(shp[i])*arr.stride(i); if (rev_axis[i]) { rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); rev_jump[i] = 1; } else rp -= ptrdiff_t(shp[i])*arr.stride(i); } } ptrdiff_t ofs() const { return p; } ptrdiff_t rev_ofs() const { return rp; } size_t remaining() const { return rem; } }; /* Allocates scratch storage with axsize entries, or axsize*vlen entries when info holds at least vlen lines of length axsize (info.size()/axsize >= vlen). axsize must not be zero. */ template aligned_array alloc_tmp (const fmav_info &info, size_t axsize) { auto othersize = info.size()/axsize; constexpr auto vlen = native_simd::size(); auto tmpsize = axsize*((othersize>=vlen) ? 
vlen : 1); return aligned_array(tmpsize); } /* copy_input overloads: fill the buffer dst from src at the positions given by the iterator. The plain-pointer overload returns at once when dst already points at the first source element of the line (in-place case). NOTE(review): the loop bodies are cut off in this copy of the file. */ template void copy_input(const multi_iter &it, const fmav> &src, Cmplx> *MRUTIL_RESTRICT dst) { for (size_t i=0; i void copy_input(const multi_iter &it, const fmav &src, native_simd *MRUTIL_RESTRICT dst) { for (size_t i=0; i void copy_input(const multi_iter &it, const fmav &src, T *MRUTIL_RESTRICT dst) { if (dst == &src[it.iofs(0)]) return; // in-place for (size_t i=0; i /* copy_output overloads: counterpart of copy_input, writing through dst.vdata(); the plain-pointer overload returns at once when src already points at the first destination element of the line. */ void copy_output(const multi_iter &it, const Cmplx> *MRUTIL_RESTRICT src, fmav> &dst) { auto ptr=dst.vdata(); for (size_t i=0; i void copy_output(const multi_iter &it, const native_simd *MRUTIL_RESTRICT src, fmav &dst) { auto ptr=dst.vdata(); for (size_t i=0; i void copy_output(const multi_iter &it, const T *MRUTIL_RESTRICT src, fmav &dst) { auto ptr=dst.vdata(); if (src == &dst[it.oofs(0)]) return; // in-place for (size_t i=0; i struct add_vec { using type = native_simd; }; template struct add_vec> { using type = Cmplx>; }; template using add_vec_t = typename add_vec::type; /* Runs exec along every axis listed in axes, one axis after the other. The first listed axis reads from in, every later one reads from out (see tin). A plan for the current axis length is obtained through get_plan, the lines are distributed with execParallel, groups of vlen lines are processed first (unless MRUTIL_NO_SIMD is defined) and leftover lines one by one. fct is set to 1 after the first axis so that the factor is applied only once. */ template MRUTIL_NOINLINE void general_nd(const fmav &in, fmav &out, const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, const bool allow_inplace=true) { std::shared_ptr plan; for (size_t iax=0; iaxlength())) plan = get_plan(len); execParallel( util::thread_count(nthreads, in, axes[iax], native_simd::size()), [&](Scheduler &sched) { constexpr auto vlen = native_simd::size(); auto storage = alloc_tmp(in, len); const auto &tin(iax==0? in : out); multi_iter it(tin, out, axes[iax], sched.num_threads(), sched.thread_num()); #ifndef MRUTIL_NO_SIMD if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); exec(it, tin, out, tdatav, *plan, fct); } #endif while (it.remaining()>0) { it.advance(1); auto buf = allow_inplace && it.stride_out() == 1 ? 
&out.vraw(it.oofs(0)) : reinterpret_cast(storage.data()); /* buf is the output line itself when in-place work is allowed and the output stride is 1, otherwise the scratch storage */ exec(it, tin, out, buf, *plan, fct); } }); // end of parallel region fct = T0(1); // factor has been applied, use 1 for remaining axes } } /* Functor passed to general_nd: copies a line into buf, runs the complex plan in the stored direction, copies the result to out. */ struct ExecC2C { bool forward; template void operator() ( const multi_iter &it, const fmav> &in, fmav> &out, T *buf, const pocketfft_c &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, forward); copy_output(it, buf, out); } }; template void copy_hartley(const multi_iter &it, const native_simd *MRUTIL_RESTRICT src, fmav &dst) { auto ptr = dst.vdata(); for (size_t j=0; j void copy_hartley(const multi_iter &it, const T *MRUTIL_RESTRICT src, fmav &dst) { auto ptr = dst.vdata(); ptr[it.oofs(0)] = src[0]; size_t i=1, i1=1, i2=it.length_out()-1; for (i=1; i void operator () ( const multi_iter &it, const fmav &in, fmav &out, T * buf, const pocketfft_r &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, true); copy_hartley(it, buf, out); } }; /* Functor passed to general_nd: copies a line into buf, calls the plan with the stored ortho, type and cosine values, copies the result to out. */ struct ExecDcst { bool ortho; int type; bool cosine; template void operator () (const multi_iter &it, const fmav &in, fmav &out, T * buf, const Tplan &plan, T0 fct) const { copy_input(it, in, buf); plan.exec(buf, fct, ortho, type, cosine); copy_output(it, buf, out); } }; /* Transform along a single axis from the real array in to the array out. The plan length and the scratch size come from in.shape(axis); every line is copied in and run through the real plan with its direction flag set to true. The forward argument selects between two output loops whose text is cut off in this copy of the file. */ template MRUTIL_NOINLINE void general_r2c( const fmav &in, fmav> &out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(in.shape(axis)); size_t len=in.shape(axis); execParallel( util::thread_count(nthreads, in, axis, native_simd::size()), [&](Scheduler &sched) { constexpr auto vlen = native_simd::size(); auto storage = alloc_tmp(in, len); multi_iter it(in, out, axis, sched.num_threads(), sched.thread_num()); #ifndef MRUTIL_NO_SIMD if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); copy_input(it, in, tdatav); plan->exec(tdatav, fct, true); auto vout = out.vdata(); for (size_t j=0; j0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); 
copy_input(it, in, tdata); plan->exec(tdata, fct, true); auto vout = out.vdata(); vout[it.oofs(0)].Set(tdata[0]); size_t i=1, ii=1; if (forward) for (; i /* Transform along a single axis from the array in to the real array out. The plan length and the scratch size come from out.shape(axis); every line is assembled in scratch storage, run through the real plan with its direction flag set to false, and copied to out. The forward argument selects between two input loops whose text is cut off in this copy of the file. */ MRUTIL_NOINLINE void general_c2r( const fmav> &in, fmav &out, size_t axis, bool forward, T fct, size_t nthreads) { auto plan = get_plan>(out.shape(axis)); size_t len=out.shape(axis); execParallel( util::thread_count(nthreads, in, axis, native_simd::size()), [&](Scheduler &sched) { constexpr auto vlen = native_simd::size(); auto storage = alloc_tmp(out, len); multi_iter it(in, out, axis, sched.num_threads(), sched.thread_num()); #ifndef MRUTIL_NO_SIMD if (vlen>1) while (it.remaining()>=vlen) { it.advance(vlen); auto tdatav = reinterpret_cast *>(storage.data()); for (size_t j=0; jexec(tdatav, fct, false); copy_output(it, tdatav, out); } #endif while (it.remaining()>0) { it.advance(1); auto tdata = reinterpret_cast(storage.data()); tdata[0]=in[it.iofs(0)].r; { size_t i=1, ii=1; if (forward) for (; iexec(tdata, fct, false); copy_output(it, tdata, out); } }); // end of parallel region } /* Functor passed to general_nd by r2r_fftpack; holds the r2c and forward flags. NOTE(review): most of operator() is cut off in this copy of the file. */ struct ExecR2R { bool r2c, forward; template void operator () ( const multi_iter &it, const fmav &in, fmav &out, T * buf, const pocketfft_r &plan, T0 fct) const { copy_input(it, in, buf); if ((!r2c) && forward) for (size_t i=2; i /* Complex transform of in over the listed axes, written to out. Arguments are validated with sanity_check_onetype (in-place operation is detected by comparing the data pointers), empty input returns immediately, then both arrays are re-wrapped through reinterpret_cast views and handed to general_nd with ExecC2C{forward}. */ void c2c(const fmav> &in, fmav> &out, const shape_t &axes, bool forward, T fct, size_t nthreads=1) { util::sanity_check_onetype(in, out, in.data()==out.data(), axes); if (in.size()==0) return; fmav> in2(reinterpret_cast *>(in.data()), in); fmav> out2(reinterpret_cast *>(out.vdata()), out, out.writable()); general_nd>(in2, out2, axes, fct, nthreads, ExecC2C{forward}); } /* Cosine transform of in over the listed axes, written to out. type must lie in 1..4, otherwise std::invalid_argument is thrown; empty input returns right after validation. The plan class is selected by type: one for 1, one for 4, one for the remaining values. */ template void dct(const fmav &in, fmav &out, const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); util::sanity_check_onetype(in, out, in.data()==out.data(), axes); if (in.size()==0) return; const ExecDcst exec{ortho, type, true}; if (type==1) 
general_nd>(in, out, axes, fct, nthreads, exec); else if (type==4) general_nd>(in, out, axes, fct, nthreads, exec); else general_nd>(in, out, axes, fct, nthreads, exec); } template void dst(const fmav &in, fmav &out, const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); util::sanity_check_onetype(in, out, in.data()==out.data(), axes); const ExecDcst exec{ortho, type, false}; if (type==1) general_nd>(in, out, axes, fct, nthreads, exec); else if (type==4) general_nd>(in, out, axes, fct, nthreads, exec); else general_nd>(in, out, axes, fct, nthreads, exec); } template void r2c(const fmav &in, fmav> &out, size_t axis, bool forward, T fct, size_t nthreads=1) { util::sanity_check_cr(out, in, axis); if (in.size()==0) return; fmav> out2(reinterpret_cast *>(out.vdata()), out, out.writable()); general_r2c(in, out2, axis, forward, fct, nthreads); } template void r2c(const fmav &in, fmav> &out, const shape_t &axes, bool forward, T fct, size_t nthreads=1) { util::sanity_check_cr(out, in, axes); if (in.size()==0) return; r2c(in, out, axes.back(), forward, fct, nthreads); if (axes.size()==1) return; auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(out, out, newaxes, forward, T(1), nthreads); } template void c2r(const fmav> &in, fmav &out, size_t axis, bool forward, T fct, size_t nthreads=1) { util::sanity_check_cr(in, out, axis); if (in.size()==0) return; fmav> in2(reinterpret_cast *>(in.data()), in); general_c2r(in2, out, axis, forward, fct, nthreads); } template void c2r(const fmav> &in, fmav &out, const shape_t &axes, bool forward, T fct, size_t nthreads=1) { if (axes.size()==1) return c2r(in, out, axes[0], forward, fct, nthreads); util::sanity_check_cr(in, out, axes); if (in.size()==0) return; fmav> atmp(in.shape()); auto newaxes = shape_t{axes.begin(), --axes.end()}; c2c(in, atmp, newaxes, forward, T(1), nthreads); c2r(atmp, out, axes.back(), forward, fct, nthreads); } template 
void r2r_fftpack(const fmav &in, fmav &out, const shape_t &axes, bool real2hermitian, bool forward, T fct, size_t nthreads=1) { /* validate, return early on empty input, then run general_nd with ExecR2R{real2hermitian, forward} */ util::sanity_check_onetype(in, out, in.data()==out.data(), axes); if (in.size()==0) return; general_nd>(in, out, axes, fct, nthreads, ExecR2R{real2hermitian, forward}); } /* Runs general_nd with ExecHartley over the listed axes; allow_inplace is passed as false, so lines are always processed in scratch storage. */ template void r2r_separable_hartley(const fmav &in, fmav &out, const shape_t &axes, T fct, size_t nthreads=1) { util::sanity_check_onetype(in, out, in.data()==out.data(), axes); if (in.size()==0) return; general_nd>(in, out, axes, fct, nthreads, ExecHartley{}, false); } /* With a single axis this is the same as r2r_separable_hartley. Otherwise the input is transformed with r2c into a temporary whose last listed axis has length shape/2+1, and for every element v of that temporary real+imag is stored at iout.ofs() and real-imag at iout.rev_ofs() in out. */ template void r2r_genuine_hartley(const fmav &in, fmav &out, const shape_t &axes, T fct, size_t nthreads=1) { if (axes.size()==1) return r2r_separable_hartley(in, out, axes, fct, nthreads); util::sanity_check_onetype(in, out, in.data()==out.data(), axes); if (in.size()==0) return; shape_t tshp(in.shape()); tshp[axes.back()] = tshp[axes.back()]/2+1; fmav> atmp(tshp); r2c(in, atmp, axes, true, fct, nthreads); simple_iter iin(atmp); rev_iter iout(out, axes); auto vout = out.vdata(); while(iin.remaining()>0) { auto v = atmp[iin.ofs()]; vout[iout.ofs()] = v.real()+v.imag(); vout[iout.rev_ofs()] = v.real()-v.imag(); iin.advance(); iout.advance(); } } } // namespace detail_fft using detail_fft::FORWARD; using detail_fft::BACKWARD; using detail_fft::c2c; using detail_fft::c2r; using detail_fft::r2c; using detail_fft::r2r_fftpack; using detail_fft::r2r_separable_hartley; using detail_fft::r2r_genuine_hartley; using detail_fft::dct; using detail_fft::dst; } // namespace mr #endif // MRUTIL_FFT_H