Commit 5483e02e authored by Martin Reinecke's avatar Martin Reinecke

add (preliminary) type IV transforms

parent 12f9cf44
......@@ -2416,6 +2416,36 @@ template<typename T0> class T_dct3
size_t length() const { return fftplan.length(); }
};
template<typename T0> class T_dct4
{
private:
T_dct2<T0> dct2;
arr<T0> C;
public:
POCKETFFT_NOINLINE T_dct4(size_t length)
: dct2(length), C(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
C[i] = cos(T0(0.5)*pi*(T0(i)+T0(0.5))/T0(length));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const
{
// FIXME: due to the recurrence below, the algorithm is not very accurate
// Trying to find a better one...
size_t N = length();
for (size_t i=0; i<N; ++i)
c[i] *= C[i];
dct2.exec(c, fct);
for (size_t i=1; i<N; ++i)
c[i] = 2*c[i] - c[i-1];
}
size_t length() const { return dct2.length(); }
};
template<typename T0> class T_dst1
{
private:
......@@ -2456,7 +2486,6 @@ template<typename T0> class T_dst2
{
size_t N=length();
if (N==1) { c[0]*=2*fct; return; }
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
dct.exec(c, fct);
......@@ -2491,6 +2520,30 @@ template<typename T0> class T_dst3
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst4
{
private:
T_dct4<T0> dct;
public:
POCKETFFT_NOINLINE T_dst4(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct)
{
size_t N=length();
//if (N==1) { c[0]*=fct; return; }
size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct);
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
size_t length() const { return dct.length(); }
};
//
// multi-D infrastructure
......@@ -3256,6 +3309,8 @@ template<typename T> void r2r_dct(const shape_t &shape,
general_dcst<T_dct2<T>>(ain, aout, axes, fct, nthreads);
else if (type==3)
general_dcst<T_dct3<T>>(ain, aout, axes, fct, nthreads);
else if (type==4)
general_dcst<T_dct4<T>>(ain, aout, axes, fct, nthreads);
else
throw runtime_error("unsupported DCT type");
}
......@@ -3275,6 +3330,8 @@ template<typename T> void r2r_dst(const shape_t &shape,
general_dcst<T_dst2<T>>(ain, aout, axes, fct, nthreads);
else if (type==3)
general_dcst<T_dst3<T>>(ain, aout, axes, fct, nthreads);
else if (type==4)
general_dcst<T_dst4<T>>(ain, aout, axes, fct, nthreads);
else
throw runtime_error("unsupported DST type");
}
......
......@@ -221,3 +221,22 @@ def testdcst1D(len, inorm, type):
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-18)
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7)
# TEMPORARY: separate test for DCT/DST IV, since they are less accurate
@pmp("len", len1D)
@pmp("inorm", [0, 1, 2])
@pmp("type", [4])
def testdcst1D4(len, inorm, type):
a = np.random.rand(len)-0.5
b = a.astype(np.float32)
c = a.astype(np.float128)
itp = (0, 1, 3, 2, 4)
itype = itp[type]
if type != 1 or len > 1:
_assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16)
_assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13)
_assert_close(b, pypocketfft.r2r_dct(pypocketfft.r2r_dct(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5)
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16)
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13)
_assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment