From 0ad63b5f37809b540825ba2e07d5f7caf153cccc Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Thu, 9 May 2019 16:47:34 +0200 Subject: [PATCH] sync --- pocketfft_hdronly.h | 35 ++++++++++++++++++----------------- pypocketfft.cc | 19 ++++++++++++------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 40530e0..5c552ef 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(_WIN32) #include #endif @@ -2381,9 +2382,9 @@ template NOINLINE void general_r( } // namespace detail -template void c2c(const shape_t &shape, - const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - bool forward, const void *data_in, void *data_out, T fct) +template void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, T fct) { using namespace detail; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); @@ -2392,23 +2393,23 @@ template void c2c(const shape_t &shape, general_c(ain, aout, axes, forward, fct); } -template void r2c(const shape_t &shape, - const stride_t &stride_in, const stride_t &stride_out, size_t axis, - const void *data_in, void *data_out, T fct) +template void r2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, size_t axis, const T *data_in, + std::complex *data_out, T fct) { using namespace detail; - util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); + util::sanity_check(shape, stride_in, stride_out, false, axis); ndarr ain(data_in, shape, stride_in); ndarr> aout(data_out, shape, stride_out); general_r2c(ain, aout, axis, fct); } template void r2c(const shape_t &shape, const stride_t &stride_in, - const stride_t &stride_out, const shape_t &axes, const void *data_in, - void *data_out, T fct) + const stride_t &stride_out, const shape_t &axes, const T *data_in, + std::complex *data_out, T fct) { using namespace detail; - util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + util::sanity_check(shape, stride_in, stride_out, false, axes); r2c(shape, stride_in, stride_out, axes.back(), data_in, data_out, fct); if (axes.size()==1) return; @@ -2421,10 +2422,10 @@ template void r2c(const shape_t &shape, const stride_t &stride_in, template void c2r(const shape_t &shape, size_t new_size, const stride_t &stride_in, const stride_t &stride_out, size_t axis, - const void *data_in, void *data_out, T fct) + const std::complex *data_in, T *data_out, T fct) { using namespace detail; - util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); + util::sanity_check(shape, stride_in, stride_out, false, axis); shape_t shape_out(shape); shape_out[axis] = new_size; ndarr> ain(data_in, shape, stride_in); @@ -2434,7 +2435,7 @@ template void c2r(const shape_t &shape, size_t new_size, template void c2r(const shape_t &shape, size_t new_size, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - const void *data_in, void *data_out, T fct) + const std::complex *data_in, T *data_out, T fct) { using namespace detail; if (axes.size()==1) @@ -2443,13 +2444,13 @@ template void c2r(const shape_t &shape, size_t new_size, data_in, data_out, fct); return; } - util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + util::sanity_check(shape, stride_in, stride_out, false, axes); auto nval = util::prod(shape); stride_t stride_inter(shape.size()); stride_inter.back() = sizeof(cmplx); for (int i=shape.size()-2; i>=0; --i) stride_inter[i] = stride_inter[i+1]*shape[i+1]; - arr tmp(nval*sizeof(cmplx)); + arr> tmp(nval); auto newaxes = shape_t({axes.begin(), --axes.end()}); c2c(shape, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), T(1)); c2r(shape, new_size, stride_inter, stride_out, axes.back(), @@ -2458,7 +2459,7 @@ template void c2r(const shape_t &shape, size_t new_size, template void r2r_fftpack(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, size_t axis, - bool forward, const void *data_in, void *data_out, T fct) + bool forward, const T *data_in, T *data_out, T fct) { using namespace detail; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); @@ -2468,7 +2469,7 @@ template void r2r_fftpack(const shape_t &shape, template void r2r_hartley(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - const void *data_in, void *data_out, T fct) + const T *data_in, T *data_out, T fct) { using namespace detail; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); diff --git a/pypocketfft.cc b/pypocketfft.cc index 273fc87..5c74ef1 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -82,8 +82,9 @@ template py::array xfftn_internal(const py::array &in, { auto dims(copy_shape(in)); py::array res = inplace ? in : py::array_t>(dims); - c2c(dims, copy_strides(in), copy_strides(res), axes, fwd, in.data(), - res.mutable_data(), T(fct)); + c2c(dims, copy_strides(in), copy_strides(res), axes, fwd, + reinterpret_cast *>(in.data()), + reinterpret_cast *>(res.mutable_data()), T(fct)); return res; } @@ -106,8 +107,9 @@ template py::array rfftn_internal(const py::array &in, auto dims_in(copy_shape(in)), dims_out(dims_in); dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1; py::array res = py::array_t>(dims_out); - r2c(dims_in, copy_strides(in), copy_strides(res), axes, in.data(), - res.mutable_data(), T(fct)); + r2c(dims_in, copy_strides(in), copy_strides(res), axes, + reinterpret_cast(in.data()), + reinterpret_cast *>(res.mutable_data()), T(fct)); return res; } py::array rfftn(const py::array &in, py::object axes_, double fct) @@ -121,7 +123,8 @@ template py::array xrfft_scipy(const py::array &in, auto dims(copy_shape(in)); py::array res = inplace ? in : py::array_t(dims); r2r_fftpack(dims, copy_strides(in), copy_strides(res), axis, fwd, - in.data(), res.mutable_data(), T(fct)); + reinterpret_cast(in.data()), + reinterpret_cast(res.mutable_data()), T(fct)); return res; } py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace) @@ -149,7 +152,8 @@ template py::array irfftn_internal(const py::array &in, dims_out[axis] = lastsize; py::array res = py::array_t(dims_out); c2r(dims_in, lastsize, copy_strides(in), copy_strides(res), axes, - in.data(), res.mutable_data(), T(fct)); + reinterpret_cast *>(in.data()), + reinterpret_cast(res.mutable_data()), T(fct)); return res; } py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, @@ -166,7 +170,8 @@ template py::array hartley_internal(const py::array &in, auto dims(copy_shape(in)); py::array res = inplace ? in : py::array_t(dims); r2r_hartley(dims, copy_strides(in), copy_strides(res), makeaxes(in, axes_), - in.data(), res.mutable_data(), T(fct)); + reinterpret_cast(in.data()), + reinterpret_cast(res.mutable_data()), T(fct)); return res; } py::array hartley(const py::array &in, py::object axes_, double fct, -- GitLab