Commit 514e1afb authored by Martin Reinecke

improve interface

parent 0f5781c9
......@@ -2296,7 +2296,7 @@ template<typename T0> class T_cosq
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void backward(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void type2(T c[], T0 fct)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
......@@ -2315,9 +2315,6 @@ template<typename T0> class T_cosq
c[i] = T0(0.5)*(c[i]-c[i-1]);
c[i-1] = xim1;
}
// c[0] *= 2;
// if ((N&1)==0)
// c[N-1] *= 2;
fftplan.backward(c, fct);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
......@@ -2336,7 +2333,7 @@ template<typename T0> class T_cosq
c[0] *= 2;
}
template<typename T> POCKETFFT_NOINLINE void forward(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void type3(T c[], T0 fct)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
......@@ -2386,26 +2383,26 @@ template<typename T0> class T_sinq
POCKETFFT_NOINLINE T_sinq(size_t length)
: cosq(length) {}
// In-place transform of the array c, implemented on top of the wrapped
// `cosq` plan; `fct` is forwarded to that plan as its scaling factor.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `backward` signature and the `cosq.backward` call appear to be the
// removed lines and the `type2` variants their replacements -- confirm
// against the repository before relying on either form.
template<typename T> POCKETFFT_NOINLINE void backward(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void type2(T c[], T0 fct)
{
size_t N=length();
// Length-1 input: the result is just the single element scaled by 2*fct.
if (N==1) { c[0]*=2*fct; return; }
// Step 1: flip the sign of every odd-indexed element.
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
// Step 2: run the corresponding transform of the cosine plan in place.
cosq.backward(c, fct);
cosq.type2(c, fct);
// Step 3: reverse the order of the elements (k walks up, kc walks down).
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
}
// In-place transform of the array c, implemented on top of the wrapped
// `cosq` plan; `fct` is forwarded to that plan as its scaling factor.
// The steps mirror the sibling method in reverse order: reverse the array,
// apply the cosine-plan transform, then negate odd-indexed elements.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `forward` signature and the `cosq.forward` call appear to be the
// removed lines and the `type3` variants their replacements -- confirm
// against the repository before relying on either form.
template<typename T> POCKETFFT_NOINLINE void forward(T c[], T0 fct)
template<typename T> POCKETFFT_NOINLINE void type3(T c[], T0 fct)
{
size_t N=length();
// Length-1 input: the result is just the single element scaled by fct.
if (N==1) { c[0]*=fct; return; }
size_t NS2 = N/2;
// Step 1: reverse the array by swapping the first N/2 elements with their
// mirror positions (the middle element of an odd-length array stays put).
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
// Step 2: run the corresponding transform of the cosine plan in place.
cosq.forward(c, fct);
cosq.type3(c, fct);
// Step 3: flip the sign of every odd-indexed element.
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
......@@ -2858,7 +2855,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
}
template<typename T> POCKETFFT_NOINLINE void general_dct23(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool forward, T fct,
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool type2, T fct,
size_t POCKETFFT_NTHREADS)
{
shared_ptr<T_cosq<T>> plan;
......@@ -2887,7 +2884,7 @@ template<typename T> POCKETFFT_NOINLINE void general_dct23(
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
forward ? plan->forward (tdatav, fct) : plan->backward(tdatav, fct);
type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
......@@ -2898,20 +2895,20 @@ template<typename T> POCKETFFT_NOINLINE void general_dct23(
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
forward ? plan->forward (tdata, fct) : plan->backward(tdata, fct);
type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
......@@ -2921,7 +2918,7 @@ template<typename T> POCKETFFT_NOINLINE void general_dct23(
}
}
template<typename T> POCKETFFT_NOINLINE void general_dst23(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool forward, T fct,
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool type2, T fct,
size_t POCKETFFT_NTHREADS)
{
shared_ptr<T_sinq<T>> plan;
......@@ -2950,7 +2947,7 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23(
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
forward ? plan->forward (tdatav, fct) : plan->backward(tdatav, fct);
type2 ? plan->type2(tdatav, fct) : plan->type3(tdatav, fct);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
......@@ -2961,20 +2958,20 @@ template<typename T> POCKETFFT_NOINLINE void general_dst23(
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
type2 ? plan->type2(&out[it.oofs(0)], fct)
: plan->type3(&out[it.oofs(0)], fct);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
forward ? plan->forward (tdata, fct) : plan->backward(tdata, fct);
type2 ? plan->type2(tdata, fct) : plan->type3(tdata, fct);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
......@@ -3228,26 +3225,30 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
general_c(ain, aout, axes, forward, fct, nthreads);
}
// Entry point for the real-to-real cosine transforms over the given axes.
// Wraps the raw input/output pointers in array views and delegates to
// general_dct23.
//   shape       - extents of the array
//   stride_in   - strides of the input array (units not visible here;
//                 presumably bytes -- TODO confirm)
//   stride_out  - strides of the output array
//   axes        - axes along which the transform is applied
//   type        - transform type; only 2 and 3 are accepted by this version
//   data_in     - input data
//   data_out    - output data (may equal data_in; that fact is passed on to
//                 util::sanity_check)
//   fct         - scaling factor handed to the transform
//   nthreads    - thread count forwarded to general_dct23 (default 1)
// Throws runtime_error when `type` is outside 1..4, or is 1 or 4.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `r2r_dct23` name, the `bool forward` parameter line and the
// `general_dct23(..., forward, ...)` call appear to be the removed lines and
// the `r2r_dct` / `int type` / `type==2` lines their replacements -- confirm.
template<typename T> void r2r_dct23(const shape_t &shape,
template<typename T> void r2r_dct(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{
// Reject type numbers outside 1..4 first, then the two that are in range
// but not implemented by this code.
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type");
// NOTE(review): the message below ends with a trailing space.
if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported ");
// Nothing to do for an array with zero elements.
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
// The boolean argument selects the type-2 path when true and the type-3
// path otherwise.
general_dct23(ain, aout, axes, forward, fct, nthreads);
general_dct23(ain, aout, axes, type==2, fct, nthreads);
}
// Entry point for the real-to-real sine transforms over the given axes.
// Wraps the raw input/output pointers in array views and delegates to
// general_dst23.
//   shape       - extents of the array
//   stride_in   - strides of the input array (units not visible here;
//                 presumably bytes -- TODO confirm)
//   stride_out  - strides of the output array
//   axes        - axes along which the transform is applied
//   type        - transform type; only 2 and 3 are accepted by this version
//   data_in     - input data
//   data_out    - output data (may equal data_in; that fact is passed on to
//                 util::sanity_check)
//   fct         - scaling factor handed to the transform
//   nthreads    - thread count forwarded to general_dst23 (default 1)
// Throws runtime_error when `type` is outside 1..4, or is 1 or 4.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `r2r_dst23` name, the `bool forward` parameter line and the
// `general_dst23(..., forward, ...)` call appear to be the removed lines and
// the `r2r_dst` / `int type` / `type==2` lines their replacements -- confirm.
template<typename T> void r2r_dst23(const shape_t &shape,
template<typename T> void r2r_dst(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1)
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{
// Reject type numbers outside 1..4 first, then the two that are in range
// but not implemented by this code.
if ((type<1) || (type>4)) throw runtime_error("invalid DST type");
// NOTE(review): the message below ends with a trailing space.
if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported ");
// Nothing to do for an array with zero elements.
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
// The boolean argument selects the type-2 path when true and the type-3
// path otherwise.
general_dst23(ain, aout, axes, forward, fct, nthreads);
general_dst23(ain, aout, axes, type==2, fct, nthreads);
}
template<typename T> void r2c(const shape_t &shape_in,
......@@ -3356,8 +3357,8 @@ using detail::c2r;
using detail::r2c;
using detail::r2r_fftpack;
using detail::r2r_separable_hartley;
using detail::r2r_dct23;
using detail::r2r_dst23;
using detail::r2r_dct;
using detail::r2r_dst;
} // namespace pocketfft
......
......@@ -92,12 +92,12 @@ template<typename T> T norm_fct(int inorm, size_t N)
}
// Normalisation factor for a transform over several axes.
// Builds the product N of the (adjusted) lengths of all transformed axes and
// delegates to the norm_fct<T>(inorm, N) overload for the actual factor.
//   inorm - normalisation mode; 0 means "no normalisation" and yields T(1)
//   shape - extents of the array
//   axes  - indices into `shape` of the transformed axes
//   fct   - multiplier applied to each axis length before use (default 1)
//   delta - offset added to each scaled axis length (default 0)
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The parameter line ending in `&axes)` and the statement `N *= shape[a];`
// appear to be the removed lines; the line adding `fct`/`delta` and the
// statement using them appear to be their replacements -- confirm.
template<typename T> T norm_fct(int inorm, const shape_t &shape,
const shape_t &axes)
const shape_t &axes, size_t fct=1, int delta=0)
{
if (inorm==0) return T(1);
size_t N(1);
for (auto a: axes)
N *= shape[a];
// Effective length per axis is shape[a]*fct + delta; the sum is formed in
// int64_t so that a negative delta can be applied before converting back
// to size_t.
N *= size_t(int64_t(shape[a]*fct)+delta);
return norm_fct<T>(inorm, N);
}
......@@ -226,8 +226,8 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_,
real2hermitian, forward, inorm, out_, nthreads))
}
template<typename T> py::array r2r_dct23_internal(const py::array &in,
const py::object &axes_, bool forward, int inorm, py::object &out_,
template<typename T> py::array r2r_dct_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -239,24 +239,27 @@ template<typename T> py::array r2r_dct23_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes);
if (inorm==2) fct*=T(1/ldbl_t(2));
if (inorm==1) fct*=T(1/sqrt(ldbl_t(2)));
pocketfft::r2r_dct23(dims, s_in, s_out, axes, forward, d_in, d_out, fct,
nthreads);
if ((type==2)||(type==3))
{
T fct = norm_fct<T>(inorm, dims, axes, 2);
pocketfft::r2r_dct(dims, s_in, s_out, axes, type, d_in, d_out, fct,
nthreads);
}
}
return res;
}
// Python-facing cosine-transform wrapper: validates `type` and then hands
// all arguments to the templated r2r_dct_internal through the DISPATCH
// macro.
// NOTE(review): DISPATCH is defined outside this view; it presumably selects
// the instantiation (f64 / f32 / flong) matching the dtype of `in` -- confirm.
// Throws runtime_error when `type` is outside 1..4, or is 1 or 4.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `r2r_dct23` signature and the DISPATCH call naming r2r_dct23_internal
// appear to be the removed lines; the `r2r_dct` signature, the two type
// checks and the DISPATCH call naming r2r_dct_internal appear to be their
// replacements -- confirm against the repository.
py::array r2r_dct23(const py::array &in, const py::object &axes_,
bool forward, int inorm, py::object &out_, size_t nthreads)
py::array r2r_dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
{
DISPATCH(in, f64, f32, flong, r2r_dct23_internal, (in, axes_,
forward, inorm, out_, nthreads))
// Reject type numbers outside 1..4 first, then the two that are in range
// but not implemented by this code.
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type");
// NOTE(review): the message below ends with a trailing space.
if ((type==1) || (type==4)) throw runtime_error("DCT type not yet supported ");
DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm,
out_, nthreads))
}
template<typename T> py::array r2r_dst23_internal(const py::array &in,
const py::object &axes_, bool forward, int inorm, py::object &out_,
template<typename T> py::array r2r_dst_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -268,20 +271,23 @@ template<typename T> py::array r2r_dst23_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes);
if (inorm==2) fct*=T(1/ldbl_t(2));
if (inorm==1) fct*=T(1/sqrt(ldbl_t(2)));
pocketfft::r2r_dst23(dims, s_in, s_out, axes, forward, d_in, d_out, fct,
nthreads);
if ((type==2)||(type==3))
{
T fct = norm_fct<T>(inorm, dims, axes, 2);
pocketfft::r2r_dst(dims, s_in, s_out, axes, type, d_in, d_out, fct,
nthreads);
}
}
return res;
}
// Python-facing sine-transform wrapper: validates `type` and then hands all
// arguments to the templated r2r_dst_internal through the DISPATCH macro.
// NOTE(review): DISPATCH is defined outside this view; it presumably selects
// the instantiation (f64 / f32 / flong) matching the dtype of `in` -- confirm.
// Throws runtime_error when `type` is outside 1..4, or is 1 or 4.
// NOTE(review): this span is taken from a rendered diff without +/- markers.
// The `r2r_dst23` signature and the DISPATCH call naming r2r_dst23_internal
// appear to be the removed lines; the `r2r_dst` signature, the two type
// checks and the DISPATCH call naming r2r_dst_internal appear to be their
// replacements -- confirm against the repository.
py::array r2r_dst23(const py::array &in, const py::object &axes_,
bool forward, int inorm, py::object &out_, size_t nthreads)
py::array r2r_dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
{
DISPATCH(in, f64, f32, flong, r2r_dst23_internal, (in, axes_,
forward, inorm, out_, nthreads))
// Reject type numbers outside 1..4 first, then the two that are in range
// but not implemented by this code.
if ((type<1) || (type>4)) throw runtime_error("invalid DST type");
// NOTE(review): the message below ends with a trailing space.
if ((type==1) || (type==4)) throw runtime_error("DST type not yet supported ");
DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm,
out_, nthreads))
}
template<typename T> py::array c2r_internal(const py::array &in,
......@@ -601,8 +607,8 @@ PYBIND11_MODULE(pypocketfft, m)
"axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a,
"axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("r2r_dct23", r2r_dct23, /*r2r_dct23_DS,*/ "a"_a,
"axes"_a=None, "forward"_a, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("r2r_dst23", r2r_dst23, /*r2r_dst23_DS,*/ "a"_a,
"axes"_a=None, "forward"_a, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("r2r_dct", r2r_dct, /*r2r_dct_DS,*/ "a"_a, "type"_a, "axes"_a=None,
"inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("r2r_dst", r2r_dst, /*r2r_dst_DS,*/ "a"_a, "type"_a, "axes"_a=None,
"inorm"_a=0, "out"_a=None, "nthreads"_a=1);
}
......@@ -211,9 +211,9 @@ def testdcst1D(len, inorm):
a = np.random.rand(len)-0.5
b = a.astype(np.float32)
c = a.astype(np.float128)
_assert_close(a, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(c, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 2e-18)
_assert_close(a, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(a, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 1.5e-15)
_assert_close(b, pypocketfft.r2r_dct23(pypocketfft.r2r_dct23(b, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 6e-7)
_assert_close(a, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(c, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 2e-18)
_assert_close(a, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(a, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 1.5e-15)
_assert_close(b, pypocketfft.r2r_dst23(pypocketfft.r2r_dst23(b, inorm=inorm, forward=True), inorm=2-inorm, forward=False), 6e-7)
_assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(c, inorm=inorm, type=2), inorm=2-inorm, type=3), 2e-18)
_assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(a, inorm=inorm, type=2), inorm=2-inorm, type=3), 1.5e-15)
_assert_close(b, pypocketfft.r2r_dct(pypocketfft.r2r_dct(b, inorm=inorm, type=2), inorm=2-inorm, type=3), 6e-7)
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=2), inorm=2-inorm, type=3), 2e-18)
_assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=2), inorm=2-inorm, type=3), 1.5e-15)
_assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=2), inorm=2-inorm, type=3), 6e-7)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment