Planned maintenance on Wednesday, 2021-01-20, 17:00-18:00. Expect some interruptions during that time

Commit 514e1afb by Martin Reinecke

improve interface

parent 0f5781c9
 ... @@ -2296,7 +2296,7 @@ template class T_cosq ... @@ -2296,7 +2296,7 @@ template class T_cosq twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); } } template POCKETFFT_NOINLINE void backward(T c[], T0 fct) template POCKETFFT_NOINLINE void type2(T c[], T0 fct) { { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); size_t N=length(); ... @@ -2315,9 +2315,6 @@ template class T_cosq ... @@ -2315,9 +2315,6 @@ template class T_cosq c[i] = T0(0.5)*(c[i]-c[i-1]); c[i] = T0(0.5)*(c[i]-c[i-1]); c[i-1] = xim1; c[i-1] = xim1; } } // c[0] *= 2; // if ((N&1)==0) // c[N-1] *= 2; fftplan.backward(c, fct); fftplan.backward(c, fct); for (size_t k=1, kc=N-1; k class T_cosq ... @@ -2336,7 +2333,7 @@ template class T_cosq c[0] *= 2; c[0] *= 2; } } template POCKETFFT_NOINLINE void forward(T c[], T0 fct) template POCKETFFT_NOINLINE void type3(T c[], T0 fct) { { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); size_t N=length(); ... @@ -2386,26 +2383,26 @@ template class T_sinq ... @@ -2386,26 +2383,26 @@ template class T_sinq POCKETFFT_NOINLINE T_sinq(size_t length) POCKETFFT_NOINLINE T_sinq(size_t length) : cosq(length) {} : cosq(length) {} template POCKETFFT_NOINLINE void backward(T c[], T0 fct) template POCKETFFT_NOINLINE void type2(T c[], T0 fct) { { size_t N=length(); size_t N=length(); if (N==1) { c[0]*=2*fct; return; } if (N==1) { c[0]*=2*fct; return; } for (size_t k=1; k POCKETFFT_NOINLINE void forward(T c[], T0 fct) template POCKETFFT_NOINLINE void type3(T c[], T0 fct) { { size_t N=length(); size_t N=length(); if (N==1) { c[0]*=fct; return; } if (N==1) { c[0]*=fct; return; } size_t NS2 = N/2; size_t NS2 = N/2; for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void general_hartley( ... @@ -2858,7 +2855,7 @@ template POCKETFFT_NOINLINE void general_hartley( } } template POCKETFFT_NOINLINE void general_dct23( template POCKETFFT_NOINLINE void general_dct23( const cndarr &in, ndarr &out, const shape_t &axes, bool forward, T fct, const cndarr &in, ndarr &out, const shape_t &axes, bool type2, T fct, size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS) { { shared_ptr> plan; shared_ptr> plan; ... @@ -2887,7 +2884,7 @@ template POCKETFFT_NOINLINE void general_dct23( ... @@ -2887,7 +2884,7 @@ template POCKETFFT_NOINLINE void general_dct23( for (size_t i=0; iforward (tdatav, fct) : plan->backward(tdatav, fct); type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct); for (size_t i=0; i POCKETFFT_NOINLINE void general_dct23( ... @@ -2898,20 +2895,20 @@ template POCKETFFT_NOINLINE void general_dct23( it.advance(1); it.advance(1); auto tdata = reinterpret_cast(storage.data()); auto tdata = reinterpret_cast(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place forward ? plan->forward (&out[it.oofs(0)], fct) type2 ? plan->type2(&out[it.oofs(0)], fct) : plan->backward(&out[it.oofs(0)], fct); : plan->type3(&out[it.oofs(0)], fct); else if (it.stride_out()==sizeof(T)) // compute FFT in output location else if (it.stride_out()==sizeof(T)) // compute FFT in output location { { for (size_t i=0; iforward (&out[it.oofs(0)], fct) type2 ? plan->type2(&out[it.oofs(0)], fct) : plan->backward(&out[it.oofs(0)], fct); : plan->type3(&out[it.oofs(0)], fct); } } else else { { for (size_t i=0; iforward (tdata, fct) : plan->backward(tdata, fct); type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct); for (size_t i=0; i POCKETFFT_NOINLINE void general_dct23( ... @@ -2921,7 +2918,7 @@ template POCKETFFT_NOINLINE void general_dct23( } } } } template POCKETFFT_NOINLINE void general_dst23( template POCKETFFT_NOINLINE void general_dst23( const cndarr &in, ndarr &out, const shape_t &axes, bool forward, T fct, const cndarr &in, ndarr &out, const shape_t &axes, bool type2, T fct, size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS) { { shared_ptr> plan; shared_ptr> plan; ... @@ -2950,7 +2947,7 @@ template POCKETFFT_NOINLINE void general_dst23( ... @@ -2950,7 +2947,7 @@ template POCKETFFT_NOINLINE void general_dst23( for (size_t i=0; iforward (tdatav, fct) : plan->backward(tdatav, fct); type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct); for (size_t i=0; i POCKETFFT_NOINLINE void general_dst23( ... @@ -2961,20 +2958,20 @@ template POCKETFFT_NOINLINE void general_dst23( it.advance(1); it.advance(1); auto tdata = reinterpret_cast(storage.data()); auto tdata = reinterpret_cast(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place forward ? plan->forward (&out[it.oofs(0)], fct) type2 ? plan->type2(&out[it.oofs(0)], fct) : plan->backward(&out[it.oofs(0)], fct); : plan->type3(&out[it.oofs(0)], fct); else if (it.stride_out()==sizeof(T)) // compute FFT in output location else if (it.stride_out()==sizeof(T)) // compute FFT in output location { { for (size_t i=0; iforward (&out[it.oofs(0)], fct) type2 ? plan->type2(&out[it.oofs(0)], fct) : plan->backward(&out[it.oofs(0)], fct); : plan->type3(&out[it.oofs(0)], fct); } } else else { { for (size_t i=0; iforward (tdata, fct) : plan->backward(tdata, fct); type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct); for (size_t i=0; i void c2c(const shape_t &shape, const stride_t &stride_in, ... @@ -3228,26 +3225,30 @@ template void c2c(const shape_t &shape, const stride_t &stride_in, general_c(ain, aout, axes, forward, fct, nthreads); general_c(ain, aout, axes, forward, fct, nthreads); } } template void r2r_dct23(const shape_t &shape, template void r2r_dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { { if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported "); if (util::prod(shape)==0) return; if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); ndarr aout(data_out, shape, stride_out); general_dct23(ain, aout, axes, forward, fct, nthreads); general_dct23(ain, aout, axes, type==2, fct, nthreads); } } template void r2r_dst23(const shape_t &shape, template void r2r_dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { { if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported "); if (util::prod(shape)==0) return; if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); ndarr aout(data_out, shape, stride_out); general_dst23(ain, aout, axes, forward, fct, nthreads); general_dst23(ain, aout, axes, type==2, fct, nthreads); } } template void r2c(const shape_t &shape_in, template void r2c(const shape_t &shape_in, ... @@ -3356,8 +3357,8 @@ using detail::c2r; ... @@ -3356,8 +3357,8 @@ using detail::c2r; using detail::r2c; using detail::r2c; using detail::r2r_fftpack; using detail::r2r_fftpack; using detail::r2r_separable_hartley; using detail::r2r_separable_hartley; using detail::r2r_dct23; using detail::r2r_dct; using detail::r2r_dst23; using detail::r2r_dst; } // namespace pocketfft } // namespace pocketfft ... ...
 ... @@ -92,12 +92,12 @@ template T norm_fct(int inorm, size_t N) ... @@ -92,12 +92,12 @@ template T norm_fct(int inorm, size_t N) } } template T norm_fct(int inorm, const shape_t &shape, template T norm_fct(int inorm, const shape_t &shape, const shape_t &axes) const shape_t &axes, size_t fct=1, int delta=0) { { if (inorm==0) return T(1); if (inorm==0) return T(1); size_t N(1); size_t N(1); for (auto a: axes) for (auto a: axes) N *= shape[a]; N *= size_t(int64_t(shape[a]*fct)+delta); return norm_fct(inorm, N); return norm_fct(inorm, N); } } ... @@ -226,8 +226,8 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, ... @@ -226,8 +226,8 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, real2hermitian, forward, inorm, out_, nthreads)) real2hermitian, forward, inorm, out_, nthreads)) } } template py::array r2r_dct23_internal(const py::array &in, template py::array r2r_dct_internal(const py::array &in, const py::object &axes_, bool forward, int inorm, py::object &out_, const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) size_t nthreads) { { auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_); ... @@ -239,24 +239,27 @@ template py::array r2r_dct23_internal(const py::array &in, ... @@ -239,24 +239,27 @@ template py::array r2r_dct23_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); auto d_out=reinterpret_cast(res.mutable_data()); { { py::gil_scoped_release release; py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); if ((type==2)||(type==3)) if (inorm==2) fct*=T(1/ldbl_t(2)); { if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); T fct = norm_fct(inorm, dims, axes, 2); pocketfft::r2r_dct23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, pocketfft::r2r_dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads); nthreads); } } } return res; return res; } } py::array r2r_dct23(const py::array &in, const py::object &axes_, py::array r2r_dct(const py::array &in, int type, const py::object &axes_, bool forward, int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads) { { DISPATCH(in, f64, f32, flong, r2r_dct23_internal, (in, axes_, if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); forward, inorm, out_, nthreads)) if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported "); DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm, out_, nthreads)) } } template py::array r2r_dst23_internal(const py::array &in, template py::array r2r_dst_internal(const py::array &in, const py::object &axes_, bool forward, int inorm, py::object &out_, const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) size_t nthreads) { { auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_); ... @@ -268,20 +271,23 @@ template py::array r2r_dst23_internal(const py::array &in, ... @@ -268,20 +271,23 @@ template py::array r2r_dst23_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); auto d_out=reinterpret_cast(res.mutable_data()); { { py::gil_scoped_release release; py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); if ((type==2)||(type==3)) if (inorm==2) fct*=T(1/ldbl_t(2)); { if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); T fct = norm_fct(inorm, dims, axes, 2); pocketfft::r2r_dst23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, pocketfft::r2r_dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads); nthreads); } } } return res; return res; } } py::array r2r_dst23(const py::array &in, const py::object &axes_, py::array r2r_dst(const py::array &in, int type, const py::object &axes_, bool forward, int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads) { { DISPATCH(in, f64, f32, flong, r2r_dst23_internal, (in, axes_, if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); forward, inorm, out_, nthreads)) if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported "); DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm, out_, nthreads)) } } template py::array c2r_internal(const py::array &in, template py::array c2r_internal(const py::array &in, ... @@ -601,8 +607,8 @@ PYBIND11_MODULE(pypocketfft, m) ... @@ -601,8 +607,8 @@ PYBIND11_MODULE(pypocketfft, m) "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("r2r_dct23", r2r_dct23, /*r2r_dct23_DS,*/ "a"_a, m.def("r2r_dct", r2r_dct, /*r2r_dct_DS,*/ "a"_a, "type"_a, "axes"_a=None, "axes"_a=None, "forward"_a, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("r2r_dst23", r2r_dst23, /*r2r_dst23_DS,*/ "a"_a, m.def("r2r_dst", r2r_dst, /*r2r_dst_DS,*/ "a"_a, "type"_a, "axes"_a=None, "axes"_a=None, "forward"_a, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); "inorm"_a=0, "out"_a=None, "nthreads"_a=1); } }
 ... @@ -211,9 +211,9 @@ def testdcst1D(len, inorm): ... @@ -211,9 +211,9 @@ def testdcst1D(len, inorm): a = np.random.rand(len)-0.5 a = np.random.rand(len)-0.5 b = a.astype(np.float32) b = a.astype(np.float32) c = a.astype(np.float128) c = a.astype(np.float128) _assert_close(a, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(c, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 2e-18) _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(c, inorm=inorm, type=2), inorm=2-inorm, type=3), 2e-18) _assert_close(a, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(a, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 1.5e-15) _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(a, inorm=inorm, type=2), inorm=2-inorm, type=3), 1.5e-15) _assert_close(b, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(b, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 6e-7) _assert_close(b, pypocketfft.r2r_dct(pypocketfft.r2r_dct(b, inorm=inorm, type=2), inorm=2-inorm, type=3), 6e-7) _assert_close(a, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(c, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 2e-18) _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=2), inorm=2-inorm, type=3), 2e-18) _assert_close(a, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(a, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 1.5e-15) _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=2), inorm=2-inorm, type=3), 1.5e-15) _assert_close(b, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(b, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 6e-7) _assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=2), inorm=2-inorm, type=3), 6e-7)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment