Commit 12f9cf44 authored by Martin Reinecke

add type I DCTs/DSTs; performance is suboptimal, but accuracy is good

parent 514e1afb
...@@ -2261,13 +2261,13 @@ template<typename T0> class pocketfft_r ...@@ -2261,13 +2261,13 @@ template<typename T0> class pocketfft_r
packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length)); packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
} }
template<typename T> POCKETFFT_NOINLINE void backward(T c[], T0 fct) template<typename T> POCKETFFT_NOINLINE void backward(T c[], T0 fct) const
{ {
packplan ? packplan->backward(c,fct) packplan ? packplan->backward(c,fct)
: blueplan->backward_r(c,fct); : blueplan->backward_r(c,fct);
} }
template<typename T> POCKETFFT_NOINLINE void forward(T c[], T0 fct) template<typename T> POCKETFFT_NOINLINE void forward(T c[], T0 fct) const
{ {
packplan ? packplan->forward(c,fct) packplan ? packplan->forward(c,fct)
: blueplan->forward_r(c,fct); : blueplan->forward_r(c,fct);
...@@ -2280,15 +2280,39 @@ template<typename T0> class pocketfft_r ...@@ -2280,15 +2280,39 @@ template<typename T0> class pocketfft_r
// //
// sine/cosine transforms // sine/cosine transforms
// //
// Type-I discrete cosine transform plan for `length` data points.
// The transform is evaluated with a real FFT of length 2*(length-1) applied
// to the even-symmetric extension of the input.
template<typename T0> class T_dct1
  {
  private:
    pocketfft_r<T0> fftplan;

    // Returns the length of the real FFT backing a DCT-I of `length` points.
    // Lengths below 2 are rejected up front: length==0 would wrap around in
    // the unsigned expression 2*(length-1), and length==1 would request a
    // zero-length FFT plan.
    static size_t fft_length(size_t length)
      {
      if (length<2)
        throw std::runtime_error("DCT-I requires at least 2 data points");
      return 2*(length-1);
      }

  public:
    // Builds the plan; throws std::runtime_error if length<2.
    POCKETFFT_NOINLINE T_dct1(size_t length)
      : fftplan(fft_length(length)) {}

    // Transforms the length() values stored in c in place.
    // fct is forwarded unchanged to the underlying FFT as its scaling factor.
    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
      {
      size_t N=fftplan.length(), n=N/2+1;
      arr<T> tmp(N);
      // even extension: c[0], c[1], ..., c[n-1], c[n-2], ..., c[1]
      tmp[0] = c[0];
      for (size_t i=1; i<n; ++i)
        tmp[i] = tmp[N-i] = c[i];
      fftplan.forward(tmp.data(), fct);
      // keep entry 0 and every odd-indexed entry of the FFT output
      // (presumably the real parts in pocketfft_r's packed layout -
      // TODO confirm against pocketfft_r::forward)
      c[0] = tmp[0];
      for (size_t i=1; i<n; ++i)
        c[i] = tmp[2*i-1];
      }

    // Number of data points handled by exec().
    size_t length() const { return fftplan.length()/2+1; }
  };
template<typename T0> class T_dct2
{ {
private: private:
pocketfft_r<T0> fftplan; pocketfft_r<T0> fftplan;
vector<T0> twiddle; vector<T0> twiddle;
public: public:
POCKETFFT_NOINLINE T_cosq(size_t length) POCKETFFT_NOINLINE T_dct2(size_t length)
: fftplan(length), twiddle(length) : fftplan(length), twiddle(length)
{ {
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
...@@ -2296,7 +2320,7 @@ template<typename T0> class T_cosq ...@@ -2296,7 +2320,7 @@ template<typename T0> class T_cosq
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
} }
template<typename T> POCKETFFT_NOINLINE void type2(T c[], T0 fct) template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
...@@ -2333,7 +2357,25 @@ template<typename T0> class T_cosq ...@@ -2333,7 +2357,25 @@ template<typename T0> class T_cosq
c[0] *= 2; c[0] *= 2;
} }
template<typename T> POCKETFFT_NOINLINE void type3(T c[], T0 fct) size_t length() const { return fftplan.length(); }
};
template<typename T0> class T_dct3
{
private:
pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public:
POCKETFFT_NOINLINE T_dct3(size_t length)
: fftplan(length), twiddle(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
...@@ -2374,40 +2416,79 @@ template<typename T0> class T_cosq ...@@ -2374,40 +2416,79 @@ template<typename T0> class T_cosq
size_t length() const { return fftplan.length(); } size_t length() const { return fftplan.length(); }
}; };
template<typename T0> class T_sinq template<typename T0> class T_dst1
{ {
private: private:
T_cosq<T0> cosq; pocketfft_r<T0> fftplan;
public:
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
{
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
}
size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dst2
{
private:
T_dct2<T0> dct;
public: public:
POCKETFFT_NOINLINE T_sinq(size_t length) POCKETFFT_NOINLINE T_dst2(size_t length)
: cosq(length) {} : dct(length) {}
template<typename T> POCKETFFT_NOINLINE void type2(T c[], T0 fct) template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
{ {
size_t N=length(); size_t N=length();
if (N==1) { c[0]*=2*fct; return; } if (N==1) { c[0]*=2*fct; return; }
for (size_t k=1; k<N; k+=2) for (size_t k=1; k<N; k+=2)
c[k] = -c[k]; c[k] = -c[k];
cosq.type2(c, fct); dct.exec(c, fct);
for (size_t k=0, kc=N-1; k<kc; ++k, --kc) for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]); swap(c[k], c[kc]);
} }
template<typename T> POCKETFFT_NOINLINE void type3(T c[], T0 fct) size_t length() const { return dct.length(); }
};
// Type-III discrete sine transform plan, implemented on top of T_dct3:
// the input is reversed, the DCT is applied, and odd-indexed outputs are
// negated.
template<typename T0> class T_dst3
  {
  private:
    T_dct3<T0> dct;

  public:
    POCKETFFT_NOINLINE T_dst3(size_t length)
      : dct(length) {}

    // Transforms the length() values stored in c in place, scaled by fct.
    // Declared const like the exec() members of the other transform plans;
    // it only calls const members of the wrapped DCT plan.
    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
      {
      size_t N=length();
      // a single point is only scaled
      if (N==1) { c[0]*=fct; return; }
      // reverse the input order
      size_t NS2 = N/2;
      for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
        swap(c[k], c[kc]);
      dct.exec(c, fct);
      // flip the sign of every odd-indexed output
      for (size_t k=1; k<N; k+=2)
        c[k] = -c[k];
      }

    // Number of data points handled by exec().
    size_t length() const { return dct.length(); }
  };
...@@ -2854,81 +2935,18 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley( ...@@ -2854,81 +2935,18 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
} }
} }
template<typename T> POCKETFFT_NOINLINE void general_dct23( template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool type2, T fct, const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
size_t POCKETFFT_NTHREADS) T fct, size_t POCKETFFT_NTHREADS)
{
shared_ptr<T_cosq<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan = get_plan<T_cosq<T>>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
const auto &tin(iax==0 ? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
{
using vtype = typename VTYPE<T>::type;
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype *>(storage.data());
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
}
#endif
while (it.remaining()>0)
{
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
}
} // end of parallel region
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
template<typename T> POCKETFFT_NOINLINE void general_dst23(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool type2, T fct,
size_t POCKETFFT_NTHREADS)
{ {
shared_ptr<T_sinq<T>> plan; shared_ptr<Trafo> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan = get_plan<T_sinq<T>>(len); plan = get_plan<Trafo>(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax])) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
...@@ -2947,7 +2965,7 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23( ...@@ -2947,7 +2965,7 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23(
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)]; tdatav[i][j] = tin[it.iofs(j,i)];
type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct); plan->exec(tdatav, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j]; out[it.oofs(j,i)] = tdatav[i][j];
...@@ -2958,20 +2976,18 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23( ...@@ -2958,20 +2976,18 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23(
it.advance(1); it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
type2 ? plan->type2(&out[it.oofs(0)], fct) plan->exec(&out[it.oofs(0)], fct);
: plan->type3(&out[it.oofs(0)], fct);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)]; out[it.oofs(i)] = tin[it.iofs(i)];
type2 ? plan->type2(&out[it.oofs(0)], fct) plan->exec(&out[it.oofs(0)], fct);
: plan->type3(&out[it.oofs(0)], fct);
} }
else else
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)]; tdata[i] = tin[it.iofs(i)];
type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct); plan->exec(tdata, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i]; out[it.oofs(i)] = tdata[i];
} }
...@@ -3230,12 +3246,18 @@ template<typename T> void r2r_dct(const shape_t &shape, ...@@ -3230,12 +3246,18 @@ template<typename T> void r2r_dct(const shape_t &shape,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); if ((type<1) || (type>4)) throw runtime_error("invalid DCT type");
if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported ");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); ndarr<T> aout(data_out, shape, stride_out);
general_dct23(ain, aout, axes, type==2, fct, nthreads); if (type==1)
general_dcst<T_dct1<T>>(ain, aout, axes, fct, nthreads);
else if (type==2)
general_dcst<T_dct2<T>>(ain, aout, axes, fct, nthreads);
else if (type==3)
general_dcst<T_dct3<T>>(ain, aout, axes, fct, nthreads);
else
throw runtime_error("unsupported DCT type");
} }
template<typename T> void r2r_dst(const shape_t &shape, template<typename T> void r2r_dst(const shape_t &shape,
...@@ -3243,12 +3265,18 @@ template<typename T> void r2r_dst(const shape_t &shape, ...@@ -3243,12 +3265,18 @@ template<typename T> void r2r_dst(const shape_t &shape,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); if ((type<1) || (type>4)) throw runtime_error("invalid DST type");
if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported ");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); ndarr<T> aout(data_out, shape, stride_out);
general_dst23(ain, aout, axes, type==2, fct, nthreads); if (type==1)
general_dcst<T_dst1<T>>(ain, aout, axes, fct, nthreads);
else if (type==2)
general_dcst<T_dst2<T>>(ain, aout, axes, fct, nthreads);
else if (type==3)
general_dcst<T_dst3<T>>(ain, aout, axes, fct, nthreads);
else
throw runtime_error("unsupported DST type");
} }
template<typename T> void r2c(const shape_t &shape_in, template<typename T> void r2c(const shape_t &shape_in,
......
...@@ -97,7 +97,7 @@ template<typename T> T norm_fct(int inorm, const shape_t &shape, ...@@ -97,7 +97,7 @@ template<typename T> T norm_fct(int inorm, const shape_t &shape,
if (inorm==0) return T(1); if (inorm==0) return T(1);
size_t N(1); size_t N(1);
for (auto a: axes) for (auto a: axes)
N *= size_t(int64_t(shape[a]*fct)+delta); N *= fct * size_t(int64_t(shape[a])+delta);
return norm_fct<T>(inorm, N); return norm_fct<T>(inorm, N);
} }
...@@ -239,12 +239,10 @@ template<typename T> py::array r2r_dct_internal(const py::array &in, ...@@ -239,12 +239,10 @@ template<typename T> py::array r2r_dct_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data()); auto d_out=reinterpret_cast<T *>(res.mutable_data());
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
if ((type==2)||(type==3)) T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, -1)
{ : norm_fct<T>(inorm, dims, axes, 2);
T fct = norm_fct<T>(inorm, dims, axes, 2); pocketfft::r2r_dct(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::r2r_dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads);
nthreads);
}
} }
return res; return res;
} }
...@@ -253,7 +251,6 @@ py::array r2r_dct(const py::array &in, int type, const py::object &axes_, ...@@ -253,7 +251,6 @@ py::array r2r_dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); if ((type<1) || (type>4)) throw runtime_error("invalid DCT type");
if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported ");
DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm, DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm,
out_, nthreads)) out_, nthreads))
} }
...@@ -271,12 +268,10 @@ template<typename T> py::array r2r_dst_internal(const py::array &in, ...@@ -271,12 +268,10 @@ template<typename T> py::array r2r_dst_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data()); auto d_out=reinterpret_cast<T *>(res.mutable_data());
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
if ((type==2)||(type==3)) T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, 1)
{ : norm_fct<T>(inorm, dims, axes, 2);
T fct = norm_fct<T>(inorm, dims, axes, 2); pocketfft::r2r_dst(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::r2r_dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads);
nthreads);
}
} }
return res; return res;
} }
...@@ -285,7 +280,6 @@ py::array r2r_dst(const py::array &in, int type, const py::object &axes_, ...@@ -285,7 +280,6 @@ py::array r2r_dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); if ((type<1) || (type>4)) throw runtime_error("invalid DST type");
if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported ");
DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm, DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm,
out_, nthreads)) out_, nthreads))
} }
......
...@@ -207,13 +207,17 @@ def test_genuine_hartley_2D(shp, axes): ...@@ -207,13 +207,17 @@ def test_genuine_hartley_2D(shp, axes):
@pmp("len", len1D)
@pmp("inorm", [0, 1, 2])
@pmp("type", [1, 2, 3])
def testdcst1D(len, inorm, type):
    """Round-trip check for 1D DCT/DST: applying a transform and then the
    paired inverse type with normalisation 2-inorm must reproduce the input."""
    a = np.random.rand(len)-0.5
    b = a.astype(np.float32)
    c = a.astype(np.float128)
    # transform type used for the way back, indexed by the forward type
    itype = (0, 1, 3, 2, 4)[type]
    # (expected result, transform input, tolerance)
    checks = ((a, c, 2e-18), (a, a, 1.5e-15), (b, b, 6e-7))

    def roundtrip(trafo, data):
        forward = trafo(data, inorm=inorm, type=type)
        return trafo(forward, inorm=2-inorm, type=itype)

    # the DCT checks are skipped for type 1 with a single data point
    if type != 1 or len > 1:
        for expected, data, tol in checks:
            _assert_close(expected, roundtrip(pypocketfft.r2r_dct, data), tol)
    # NOTE(review): the DST checks are assumed to run unconditionally
    # (outside the guard above) - confirm against the original indentation
    for expected, data, tol in checks:
        _assert_close(expected, roundtrip(pypocketfft.r2r_dst, data), tol)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment