diff --git a/README.md b/README.md index 1fb76bdecffe4e2559a84086494056654f9e11bf..9948ff1caf2132f1eb80c536da9dd76e8ee39d92 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ pypocketfft =========== -This package provides Fast Fourier and Hartley transforms with a simple -Python interface. +This package provides Fast Fourier, trigonometric and Hartley transforms with a +simple Python interface. The central algorithms are derived from Paul Swarztrauber's FFTPACK code (http://www.netlib.org/fftpack). @@ -10,11 +10,11 @@ The central algorithms are derived from Paul Swarztrauber's FFTPACK code Features -------- - supports fully complex and half-complex (i.e. complex-to-real and - real-to-complex) FFTs -- supports multidimensional arrays and selection of the axes to be transformed. -- supports single and double precision + real-to-complex) FFTs, discrete sine/cosine transforms and Hartley transforms +- achieves very high accuracy for all transforms +- supports multidimensional arrays and selection of the axes to be transformed +- supports single, double, and long double precision - makes use of CPU vector instructions when performing 2D and higher-dimensional transforms -- does not have persistent transform plans, which makes the interface simpler - supports prime-length transforms without degrading to O(N**2) performance -- Has optional OpenMP support for multidimensional transforms +- has optional OpenMP support for multidimensional transforms diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 6a94b114910f759a8165aaa7b605eae19ad52060..eba33984df9556b44581b039dd7226507a040ff6 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2,7 +2,13 @@ This file is part of pocketfft. Copyright (C) 2010-2019 Max-Planck-Society -Author: Martin Reinecke +Copyright (C) 2019 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +Authors: Martin Reinecke, Peter Bell All rights reserved. @@ -196,6 +202,13 @@ template struct cmplx { { r+=other.r; i+=other.i; return *this; } templatecmplx &operator*= (T2 other) { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } cmplx operator+ (const cmplx &other) const { return cmplx(r+other.r, i+other.i); } cmplx operator- (const cmplx &other) const @@ -474,8 +487,8 @@ struct util // hack to avoid duplicate symbols shape_t tmp(ndim,0); for (auto ax : axes) { - if (ax>=ndim) throw runtime_error("bad axis number"); - if (++tmp[ax]>1) throw runtime_error("axis specified repeatedly"); + if (ax>=ndim) throw invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw invalid_argument("axis specified repeatedly"); } } @@ -484,7 +497,7 @@ struct util // hack to avoid duplicate symbols size_t axis) { sanity_check(shape, stride_in, stride_out, inplace); - if (axis>=shape.size()) throw runtime_error("bad axis number"); + if (axis>=shape.size()) throw invalid_argument("bad axis number"); } #ifdef POCKETFFT_OPENMP @@ -531,7 +544,7 @@ template void pass2 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -585,7 +598,7 @@ template void pass3 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -623,7 +636,7 @@ template void pass4 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -705,7 +718,7 @@ template void pass5 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -779,7 +792,7 @@ template void pass7(size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -843,7 +856,7 @@ template void pass8 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -971,7 +984,7 @@ template void pass11 (size_t ido, size_t l1, auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto WA = [wa, ido](size_t x, size_t i) { return wa[i-1+x*(ido-1)]; }; @@ -1245,7 +1258,7 @@ template void pass_all(T c[], T0 fct) POCKETFFT_NOINLINE cfftp(size_t length_) : length(length_) { - if (length==0) throw runtime_error("zero length FFT requested"); + if (length==0) throw runtime_error("zero-length FFT requested"); if (length==1) return; factorize(); mem.resize(twsize()); @@ -1290,7 +1303,7 @@ template void radf2 (size_t ido, size_t l1, auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; - auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + auto CH = [ch,ido,cdim](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+cdim*c)]; }; for (size_t k=0; k void radf3(size_t ido, size_t l1, auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; - auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + auto CH = [ch,ido,cdim](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+cdim*c)]; }; for (size_t k=0; k void radf4(size_t ido, size_t l1, auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; - auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + auto CH = [ch,ido,cdim](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+cdim*c)]; }; for (size_t k=0; k void radf5(size_t ido, size_t l1, auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+l1*c)]; }; - auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + auto CH = [ch,ido,cdim](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+cdim*c)]; }; for (size_t k=0; k void radb2(size_t ido, size_t l1, constexpr size_t cdim=2; auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; @@ -1653,7 +1666,7 @@ template void radb3(size_t ido, size_t l1, constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; @@ -1694,7 +1707,7 @@ template void radb4(size_t ido, size_t l1, constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; @@ -1750,7 +1763,7 @@ template void radb5(size_t ido, size_t l1, ti12= T0(0.5877852522924731291687059546390728L); auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; - auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& { return cc[a+ido*(b+cdim*c)]; }; auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& { return ch[a+ido*(b+l1*c)]; }; @@ -2085,7 +2098,7 @@ template void radbg(size_t ido, size_t ip, size_t l1, POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_) { - if (length==0) throw runtime_error("zero-sized FFT"); + if (length==0) throw runtime_error("zero-length FFT requested"); if (length==1) return; factorize(); mem.resize(twsize()); @@ -2221,10 +2234,10 @@ template class pocketfft_c packplan=unique_ptr>(new cfftp(length)); } - template POCKETFFT_NOINLINE void backward(cmplx c[], T0 fct) + template POCKETFFT_NOINLINE void backward(cmplx c[], T0 fct) const { packplan ? packplan->backward(c,fct) : blueplan->backward(c,fct); } - template POCKETFFT_NOINLINE void forward(cmplx c[], T0 fct) + template POCKETFFT_NOINLINE void forward(cmplx c[], T0 fct) const { packplan ? packplan->forward(c,fct) : blueplan->forward(c,fct); } size_t length() const { return len; } @@ -2261,13 +2274,13 @@ template class pocketfft_r packplan=unique_ptr>(new rfftp(length)); } - template POCKETFFT_NOINLINE void backward(T c[], T0 fct) + template POCKETFFT_NOINLINE void backward(T c[], T0 fct) const { packplan ? packplan->backward(c,fct) : blueplan->backward_r(c,fct); } - template POCKETFFT_NOINLINE void forward(T c[], T0 fct) + template POCKETFFT_NOINLINE void forward(T c[], T0 fct) const { packplan ? packplan->forward(c,fct) : blueplan->forward_r(c,fct); @@ -2276,6 +2289,365 @@ template class pocketfft_r size_t length() const { return len; } }; + +// +// sine/cosine transforms +// +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dct2 + { + private: + pocketfft_r fftplan; + vector twiddle; + + public: + POCKETFFT_NOINLINE T_dct2(size_t length) + : fftplan(length), twiddle(length) + { + constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (N==1) + c[0]*=2*fct; + else if (N==2) + { + T x1 = 2*fct*(c[0]+c[1]); + c[1] = sqrt2*fct*(c[0]-c[1]); + c[0] = x1; + } + else + { + size_t NS2 = (N+1)/2; + for (size_t i=2; i class T_dct3 + { + private: + pocketfft_r fftplan; + vector twiddle; + + public: + POCKETFFT_NOINLINE T_dct3(size_t length) + : fftplan(length), twiddle(length) + { + constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (ortho) c[0]*=sqrt2; + if (N==1) + c[0]*=fct; + else if (N==2) + { + T TSQX = sqrt2*c[1]; + c[1] = fct*(c[0]-TSQX); + c[0] = fct*(c[0]+TSQX); + } + else + { + size_t NS2 = (N+1)/2; + for (size_t k=1, kc=N-1; k class T_dct4 + { + // even length algorithm from + // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ + private: + size_t N; + unique_ptr> fft; + unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dct4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + if ((N&1)==0) + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + if (N&1) + { + // The following code is derived from the FFTW3 function apply_re11() + // and is released under the 3-clause BSD license with friendly + // permission of Matteo Frigo. + + auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;}; + arr y(N); + size_t n2 = N/2; + size_t i; + { + size_t m; + for (i=0, m=n2; mforward(y.data(), fct); + for (i=0; i+i+1> y(N/2); + for(size_t i=0; iforward(y.data(), fct); + for(size_t i=0; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dst2 + { + private: + T_dct2 dct; + + public: + POCKETFFT_NOINLINE T_dst2(size_t length) + : dct(length) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (N==1) + c[0]*=2*fct; + else + { + for (size_t k=1; k class T_dst3 + { + private: + T_dct3 dct; + + public: + POCKETFFT_NOINLINE T_dst3(size_t length) + : dct(length) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (ortho) c[0]*=sqrt2; + if (N==1) + c[0]*=fct; + else + { + size_t NS2 = N/2; + for (size_t k=0, kc=N-1; k class T_dst4 + { + private: + T_dct4 dct; + + public: + POCKETFFT_NOINLINE T_dst4(size_t length) + : dct(length) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) + { + size_t N=length(); + //if (N==1) { c[0]*=fct; return; } + size_t NS2 = N/2; + for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void general_hartley( } } +template POCKETFFT_NOINLINE void general_dcst( + const cndarr &in, ndarr &out, const shape_t &axes, + T fct, bool ortho, size_t POCKETFFT_NTHREADS) + { + shared_ptr plan; + + for (size_t iax=0; iax::val; + size_t len=in.shape(axes[iax]); + if ((!plan) || (len!=plan->length())) + plan = get_plan(len); + +#ifdef POCKETFFT_OPENMP +#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) +#endif +{ + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0 ? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + using vtype = typename VTYPE::type; + it.advance(vlen); + auto tdatav = reinterpret_cast(storage.data()); + for (size_t i=0; iexec(tdatav, fct, ortho); + for (size_t i=0; i0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place + plan->exec(&out[it.oofs(0)], fct, ortho); + else if (it.stride_out()==sizeof(T)) // compute FFT in output location + { + for (size_t i=0; iexec(&out[it.oofs(0)], fct, ortho); + } + else + { + for (size_t i=0; iexec(tdata, fct, ortho); + for (size_t i=0; i POCKETFFT_NOINLINE void general_r2c( const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) @@ -2963,6 +3397,48 @@ template void c2c(const shape_t &shape, const stride_t &stride_in, general_c(ain, aout, axes, forward, fct, nthreads); } +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + if (type==1) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==2) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==3) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==4) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else + throw runtime_error("unsupported DCT type"); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + if (type==1) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==2) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==3) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else if (type==4) + general_dcst>(ain, aout, axes, fct, ortho, nthreads); + else + throw runtime_error("unsupported DST type"); + } + template void r2c(const shape_t &shape_in, const stride_t &stride_in, const stride_t &stride_out, size_t axis, bool forward, const T *data_in, complex *data_out, T fct, @@ -3069,6 +3545,8 @@ using detail::c2r; using detail::r2c; using detail::r2r_fftpack; using detail::r2r_separable_hartley; +using detail::dct; +using detail::dst; } // namespace pocketfft diff --git a/pypocketfft.cc b/pypocketfft.cc index efe85f0bbf551044dc78fc5772541470446925d4..1efd2f3b2abe001b832acd1b749e1af4cf292dfd 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -7,7 +7,9 @@ * Python interface. * * Copyright (C) 2019 Max-Planck-Society + * Copyright (C) 2019 Peter Bell * \author Martin Reinecke + * \author Peter Bell */ #include @@ -92,12 +94,12 @@ template T norm_fct(int inorm, size_t N) } template T norm_fct(int inorm, const shape_t &shape, - const shape_t &axes) + const shape_t &axes, size_t fct=1, int delta=0) { if (inorm==0) return T(1); size_t N(1); for (auto a: axes) - N *= shape[a]; + N *= fct * size_t(int64_t(shape[a])+delta); return norm_fct(inorm, N); } @@ -226,6 +228,66 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, real2hermitian, forward, inorm, out_, nthreads)) } +template py::array dct_internal(const py::array &in, + const py::object &axes_, int type, int inorm, py::object &out_, + size_t nthreads) + { + auto axes = makeaxes(in, axes_); + auto dims(copy_shape(in)); + py::array res = prepare_output(out_, dims); + auto s_in=copy_strides(in); + auto s_out=copy_strides(res); + auto d_in=reinterpret_cast(in.data()); + auto d_out=reinterpret_cast(res.mutable_data()); + { + py::gil_scoped_release release; + T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, -1) + : norm_fct(inorm, dims, axes, 2); + bool ortho = inorm == 1; + pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, + nthreads); + } + return res; + } + +py::array dct(const py::array &in, int type, const py::object &axes_, + int inorm, py::object &out_, size_t nthreads) + { + if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); + DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, out_, + nthreads)) + } + +template py::array dst_internal(const py::array &in, + const py::object &axes_, int type, int inorm, py::object &out_, + size_t nthreads) + { + auto axes = makeaxes(in, axes_); + auto dims(copy_shape(in)); + py::array res = prepare_output(out_, dims); + auto s_in=copy_strides(in); + auto s_out=copy_strides(res); + auto d_in=reinterpret_cast(in.data()); + auto d_out=reinterpret_cast(res.mutable_data()); + { + py::gil_scoped_release release; + T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, 1) + : norm_fct(inorm, dims, axes, 2); + bool ortho = inorm == 1; + pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, + nthreads); + } + return res; + } + +py::array dst(const py::array &in, int type, const py::object &axes_, + int inorm, py::object &out_, size_t nthreads) + { + if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); + DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, + out_, nthreads)) + } + template py::array c2r_internal(const py::array &in, const py::object &axes_, size_t lastsize, bool forward, int inorm, py::object &out_, size_t nthreads) @@ -235,7 +297,7 @@ template py::array c2r_internal(const py::array &in, shape_t dims_in(copy_shape(in)), dims_out=dims_in; if (lastsize==0) lastsize=2*dims_in[axis]-1; if ((lastsize/2) + 1 != dims_in[axis]) - throw runtime_error("bad lastsize"); + throw invalid_argument("bad lastsize"); dims_out[axis] = lastsize; py::array res = prepare_output(out_, dims_out); auto s_in=copy_strides(in); @@ -297,7 +359,6 @@ templatepy::array complex2hartley(const py::array &in, py::gil_scoped_release release; simple_iter iin(atmp); rev_iter iout(aout, axes); - if (iin.remaining()!=iout.remaining()) throw runtime_error("oops"); while(iin.remaining()>0) { auto v = atmp[iin.ofs()]; @@ -524,6 +585,83 @@ numpy.ndarray (same shape and data type as `a`) The transformed data )"""; +const char *dct_DS = R"""(Performs a discrete cosine transform. + +Parameters +---------- +a : numpy.ndarray (any real type) + The input data +type : integer + the type of DCT. Must be in [1; 4]. +axes : list of integers + The axes along which the transform is carried out. + If not set, all axes will be transformed. +inorm : int + Normalization type + 0 : no normalization + 1 : make transform orthogonal and divide by sqrt(N) + 2 : divide by N + where N is the product of n_i for every transformed axis i. + n_i is 2*(-1 for type 1 and 2* + for types 2, 3, 4. + Making the transform orthogonal involves the following additional steps + for every 1D sub-transform: + Type 1 : multiply first and last input value by sqrt(2) + divide first and last output value by sqrt(2) + Type 2 : divide first output value by sqrt(2) + Type 3 : multiply first input value by sqrt(2) + Type 4 : nothing +out : numpy.ndarray (same shape and data type as `a`) + May be identical to `a`, but if it isn't, it must not overlap with `a`. + If None, a new array is allocated to store the output. +nthreads : int + Number of threads to use. If 0, use the system default (typically governed + by the `OMP_NUM_THREADS` environment variable). + +Returns +------- +numpy.ndarray (same shape and data type as `a`) + The transformed data +)"""; + +const char *dst_DS = R"""(Performs a discrete sine transform. + +Parameters +---------- +a : numpy.ndarray (any real type) + The input data +type : integer + the type of DST. Must be in [1; 4]. +axes : list of integers + The axes along which the transform is carried out. + If not set, all axes will be transformed. +inorm : int + Normalization type + 0 : no normalization + 1 : make transform orthogonal and divide by sqrt(N) + 2 : divide by N + where N is the product of n_i for every transformed axis i. + n_i is 2*(+1 for type 1 and 2* + for types 2, 3, 4. + Making the transform orthogonal involves the following additional steps + for every 1D sub-transform: + Type 1 : nothing + Type 2 : divide first output value by sqrt(2) + Type 3 : multiply first input value by sqrt(2) + Type 4 : nothing +out : numpy.ndarray (same shape and data type as `a`) + May be identical to `a`, but if it isn't, it must not overlap with `a`. + If None, a new array is allocated to store the output. +nthreads : int + Number of threads to use. If 0, use the system default (typically governed + by the `OMP_NUM_THREADS` environment variable). + +Returns +------- +numpy.ndarray (same shape and data type as `a`) + The transformed data +)"""; + } // unnamed namespace PYBIND11_MODULE(pypocketfft, m) @@ -543,4 +681,8 @@ PYBIND11_MODULE(pypocketfft, m) "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); + m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, + "out"_a=None, "nthreads"_a=1); + m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, + "out"_a=None, "nthreads"_a=1); } diff --git a/test.py b/test.py index 5c43f82776355e9d1920476c42467cc23c4f172d..c718c8698f6e657028d17b8c615c578a09b00e25 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,13 @@ def _l2error(a, b): return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) +def _assert_close(a, b, epsilon): + err = _l2error(a, b) + if (err >= epsilon): + print("Error: {} > {}".format(err, epsilon)) + assert_(err 1: # there are no length-1 type 1 DCTs + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), eps) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), eps)