Commit 8bbee24d authored by Martin Reinecke

try to add ortho support

parent 913949e9
......@@ -2296,9 +2296,12 @@ template<typename T0> class T_dct1
POCKETFFT_NOINLINE T_dct1(size_t length)
: fftplan(2*(length-1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=fftplan.length(), n=N/2+1;
if (ortho)
{ c[0]*=sqrt2; c[n-1]*=sqrt2; }
arr<T> tmp(N);
tmp[0] = c[0];
for (size_t i=1; i<n; ++i)
......@@ -2307,6 +2310,8 @@ template<typename T0> class T_dct1
c[0] = tmp[0];
for (size_t i=1; i<n; ++i)
c[i] = tmp[2*i-1];
if (ortho)
{ c[0]/=sqrt2; c[n-1]/=sqrt2; }
}
size_t length() const { return fftplan.length()/2+1; }
......@@ -2327,18 +2332,20 @@ template<typename T0> class T_dct2
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1) { c[0]*=2*fct; return; }
if (N==2)
if (N==1)
c[0]*=2*fct;
else if (N==2)
{
T x1 = 2*fct*(c[0]+c[1]);
c[1] = sqrt2*fct*(c[0]-c[1]);
c[0] = x1;
return;
}
else
{
size_t NS2 = (N+1)/2;
for (size_t i=2; i<N; i+=2)
{
......@@ -2363,6 +2370,8 @@ template<typename T0> class T_dct2
}
c[0] *= 2;
}
if (ortho) c[0]/=sqrt2;
}
size_t length() const { return fftplan.length(); }
};
......@@ -2382,18 +2391,21 @@ template<typename T0> class T_dct3
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1) { c[0]*=fct; return; }
if (N==2)
if (ortho) c[0]*=sqrt2;
if (N==1)
c[0]*=fct;
else if (N==2)
{
T TSQX = sqrt2*c[1];
c[1] = fct*(c[0]-TSQX);
c[0] = fct*(c[0]+TSQX);
return;
}
else
{
size_t NS2 = (N+1)/2;
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
......@@ -2419,6 +2431,7 @@ template<typename T0> class T_dct3
c[i-1] = xim1;
}
}
}
size_t length() const { return fftplan.length(); }
};
......@@ -2455,7 +2468,7 @@ template<typename T0> class T_dct4
}
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{
// FIXME: due to the recurrence below, the algorithm is not very accurate
// Trying to find a better one...
......@@ -2463,7 +2476,7 @@ template<typename T0> class T_dct4
{
for (size_t i=0; i<N; ++i)
c[i] *= C[i];
dct2.exec(c, fct);
dct2.exec(c, fct, false);
for (size_t i=1; i<N; ++i)
c[i] = 2*c[i] - c[i-1];
}
......@@ -2498,7 +2511,7 @@ template<typename T0> class T_dst1
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
......@@ -2525,16 +2538,22 @@ template<typename T0> class T_dst2
POCKETFFT_NOINLINE T_dst2(size_t length)
: dct(length) {}
// In-place type-2 DST on c[0..N-1], implemented by wrapping the member plan `dct`.
//   c     - data array of length(); overwritten with the result
//   fct   - scale factor: applied directly when N==1, otherwise forwarded to dct.exec
//   ortho - when true, c[0] is additionally divided by sqrt(2) after the transform
// NOTE(review): this span is a rendered diff without +/- markers; where two
// near-identical lines follow each other, the first appears to be the
// pre-commit line and the second its replacement.
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1) { c[0]*=2*fct; return; }
// The if/else form has no early return, so the ortho adjustment at the end
// of the function is also reached for N==1.
if (N==1)
c[0]*=2*fct;
else
{
// Negate every odd-indexed input before handing the data to the DCT plan.
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
dct.exec(c, fct);
// ortho is passed as false to the wrapped plan; the ortho adjustment is
// applied once, below, after the reordering.
dct.exec(c, fct, false);
// Reverse the order of the output coefficients.
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
}
// Ortho adjustment on the first output element (after the reversal above).
if (ortho) c[0]/=sqrt2;
}
size_t length() const { return dct.length(); }
};
......@@ -2548,17 +2567,23 @@ template<typename T0> class T_dst3
POCKETFFT_NOINLINE T_dst3(size_t length)
: dct(length) {}
// In-place type-3 DST on c[0..N-1], implemented by wrapping the member plan `dct`.
//   c     - data array of length(); overwritten with the result
//   fct   - scale factor: applied directly when N==1, otherwise forwarded to dct.exec
//   ortho - when true, c[0] is multiplied by sqrt(2) before the transform
// NOTE(review): this span is a rendered diff without +/- markers; where two
// near-identical lines follow each other, the first appears to be the
// pre-commit line and the second its replacement.
// NOTE(review): unlike T_dst2::exec, this member is not const-qualified.
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1) { c[0]*=fct; return; }
// Ortho adjustment on the first input element, applied before the reversal
// below and also in the N==1 case.
if (ortho) c[0]*=sqrt2;
if (N==1)
c[0]*=fct;
else
{
size_t NS2 = N/2;
// Reverse the order of the input coefficients.
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct);
// ortho is passed as false to the wrapped plan; the adjustment was already
// applied above.
dct.exec(c, fct, false);
// Negate every odd-indexed output.
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
}
size_t length() const { return dct.length(); }
};
......@@ -2572,14 +2597,14 @@ template<typename T0> class T_dst4
POCKETFFT_NOINLINE T_dst4(size_t length)
: dct(length) {}
// In-place type-4 DST on c[0..N-1], implemented by wrapping the member plan `dct`.
//   c   - data array of length(); overwritten with the result
//   fct - scale factor forwarded to dct.exec
// The third parameter (ortho) is unnamed and has no effect in this function.
// NOTE(review): this span is a rendered diff without +/- markers; where two
// near-identical lines follow each other, the first appears to be the
// pre-commit line and the second its replacement.
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/)
{
size_t N=length();
//if (N==1) { c[0]*=fct; return; }
size_t NS2 = N/2;
// Reverse the order of the input coefficients.
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct);
// The wrapped plan is always called with ortho=false.
dct.exec(c, fct, false);
// Negate every odd-indexed output.
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
......@@ -3033,7 +3058,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
T fct, size_t POCKETFFT_NTHREADS)
T fct, bool ortho, size_t POCKETFFT_NTHREADS)
{
shared_ptr<Trafo> plan;
......@@ -3061,7 +3086,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
plan->exec(tdatav, fct);
plan->exec(tdatav, fct, ortho);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
......@@ -3072,18 +3097,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
plan->exec(&out[it.oofs(0)], fct);
plan->exec(&out[it.oofs(0)], fct, ortho);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
plan->exec(&out[it.oofs(0)], fct);
plan->exec(&out[it.oofs(0)], fct, ortho);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
plan->exec(tdata, fct);
plan->exec(tdata, fct, ortho);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
......@@ -3339,7 +3364,7 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
template<typename T> void dct(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return;
......@@ -3347,20 +3372,20 @@ template<typename T> void dct(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
if (type==1)
general_dcst<T_dct1<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==2)
general_dcst<T_dct2<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==3)
general_dcst<T_dct3<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==4)
general_dcst<T_dct4<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads);
else
throw runtime_error("unsupported DCT type");
}
template<typename T> void dst(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
if (util::prod(shape)==0) return;
......@@ -3368,13 +3393,13 @@ template<typename T> void dst(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
if (type==1)
general_dcst<T_dst1<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==2)
general_dcst<T_dst2<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==3)
general_dcst<T_dst3<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==4)
general_dcst<T_dst4<T>>(ain, aout, axes, fct, nthreads);
general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads);
else
throw runtime_error("unsupported DST type");
}
......
......@@ -227,7 +227,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_,
}
template<typename T> py::array dct_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
const py::object &axes_, int type, int inorm, bool ortho, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -239,24 +239,26 @@ template<typename T> py::array dct_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
// override
if (ortho) inorm=1;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, -1)
: norm_fct<T>(inorm, dims, axes, 2);
pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}
// Wrapper function registered with the Python module under the name "dct".
// Rejects any type outside 1..4 with invalid_argument, then hands all
// arguments (including the new `ortho` flag) to dct_internal via DISPATCH.
// dct_internal, as shown in this diff, forces inorm=1 whenever ortho is true.
// NOTE(review): DISPATCH is defined outside this view; presumably it picks
// the dct_internal<T> instantiation matching the dtype of `in` - confirm.
// NOTE(review): rendered diff without +/- markers; of each pair of
// near-identical lines the first appears to be the pre-commit version.
py::array dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, bool ortho, py::object &out_, size_t nthreads)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm,
DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, ortho,
out_, nthreads))
}
template<typename T> py::array dst_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
const py::object &axes_, int type, int inorm, bool ortho, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -268,19 +270,21 @@ template<typename T> py::array dst_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
// override
if (ortho) inorm=1;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, 1)
: norm_fct<T>(inorm, dims, axes, 2);
pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}
// Wrapper function registered with the Python module under the name "dst".
// Rejects any type outside 1..4 with invalid_argument, then hands all
// arguments (including the new `ortho` flag) to dst_internal via DISPATCH.
// dst_internal, as shown in this diff, forces inorm=1 whenever ortho is true.
// NOTE(review): DISPATCH is defined outside this view; presumably it picks
// the dst_internal<T> instantiation matching the dtype of `in` - confirm.
// NOTE(review): rendered diff without +/- markers; of each pair of
// near-identical lines the first appears to be the pre-commit version.
py::array dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, bool ortho, py::object &out_, size_t nthreads)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm,
DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, ortho,
out_, nthreads))
}
......@@ -665,7 +669,7 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a,
"axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"ortho"_a=false, "out"_a=None, "nthreads"_a=1);
m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"ortho"_a=false, "out"_a=None, "nthreads"_a=1);
}
......@@ -222,6 +222,23 @@ def testdcst1D(len, inorm, type):
_assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7)
# Round-trip test for ortho=True: a DCT/DST of the given type followed by one
# of type itp[type], both with ortho=True, must reproduce the input within
# the stated tolerance. Parametrized over the lengths in len1D and over
# transform types 1, 2 and 3 (type 4 is not covered here).
# NOTE(review): indentation was lost in this rendering; each group of three
# _assert_close calls presumably forms the body of the `if` directly above it.
@pmp("len", len1D)
@pmp("type", [1, 2, 3])
def testdcst1Dortho(len, type):
# a is the reference input; b and c are float32 and float128 copies of it.
a = np.random.rand(len)-0.5
b = a.astype(np.float32)
c = a.astype(np.float128)
# itp maps a transform type to the type used for the second call:
# 1->1, 2->3, 3->2, 4->4 (index 0 is an unused placeholder).
itp = (0, 1, 3, 2, 4)
itype = itp[type]
# DCT round trip, skipped only for type 1 with len <= 1 - presumably because
# T_dct1 builds its FFT plan with length 2*(length-1), which is 0 for len 1.
if type != 1 or len > 1:
_assert_close(a, pypocketfft.dct(pypocketfft.dct(c, ortho=True, type=type), ortho=True, type=itype), 2e-18)
_assert_close(a, pypocketfft.dct(pypocketfft.dct(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dct(pypocketfft.dct(b, ortho=True, type=type), ortho=True, type=itype), 6e-7)
# DST round trip, skipped for type 1 - presumably because T_dst1::exec leaves
# its ortho parameter unused in this commit.
if type != 1:
_assert_close(a, pypocketfft.dst(pypocketfft.dst(c, ortho=True, type=type), ortho=True, type=itype), 2e-18)
_assert_close(a, pypocketfft.dst(pypocketfft.dst(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dst(pypocketfft.dst(b, ortho=True, type=type), ortho=True, type=itype), 6e-7)
# TEMPORARY: separate test for DCT/DST IV, since they are less accurate
@pmp("len", len1D)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment