Commit e2b8a7ca authored by Martin Reinecke's avatar Martin Reinecke
Browse files

start using fmav in FFT

parent 00c8d288
...@@ -87,3 +87,4 @@ libtool ...@@ -87,3 +87,4 @@ libtool
ar-lib ar-lib
*-uninstalled.sh *-uninstalled.sh
missing missing
test-driver
...@@ -243,7 +243,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer1_geom_info (size_t nrings, size_t p ...@@ -243,7 +243,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer1_geom_info (size_t nrings, size_t p
weight[2*k ]=2./(1.-4.*k*k)*sin((k*pi)/nrings); weight[2*k ]=2./(1.-4.*k*k)*sin((k*pi)/nrings);
} }
if ((nrings&1)==0) weight[nrings-1]=0.; if ((nrings&1)==0) weight[nrings-1]=0.;
mr::r2r_fftpack({size_t(nrings)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.); mr::r2r_fftpack({size_t(nrings)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
for (size_t m=0; m<(nrings+1)/2; ++m) for (size_t m=0; m<(nrings+1)/2; ++m)
{ {
...@@ -275,7 +275,7 @@ unique_ptr<sharp_geom_info> sharp_make_cc_geom_info (size_t nrings, size_t pprin ...@@ -275,7 +275,7 @@ unique_ptr<sharp_geom_info> sharp_make_cc_geom_info (size_t nrings, size_t pprin
for (size_t k=1; k<=(n/2-1); ++k) for (size_t k=1; k<=(n/2-1); ++k)
weight[2*k-1]=2./(1.-4.*k*k) + dw; weight[2*k-1]=2./(1.-4.*k*k) + dw;
weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1. -dw*((2-(n&1))*n-1); weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1. -dw*((2-(n&1))*n-1);
mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.); mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
weight[n]=weight[0]; weight[n]=weight[0];
for (size_t m=0; m<(nrings+1)/2; ++m) for (size_t m=0; m<(nrings+1)/2; ++m)
...@@ -308,7 +308,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer2_geom_info (size_t nrings, size_t p ...@@ -308,7 +308,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer2_geom_info (size_t nrings, size_t p
for (size_t k=1; k<=(n/2-1); ++k) for (size_t k=1; k<=(n/2-1); ++k)
weight[2*k-1]=2./(1.-4.*k*k); weight[2*k-1]=2./(1.-4.*k*k);
weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1.; weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1.;
mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.); mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
for (size_t m=0; m<nrings; ++m) for (size_t m=0; m<nrings; ++m)
weight[m]=weight[m+1]; weight[m]=weight[m+1];
......
...@@ -60,6 +60,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -60,6 +60,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "mr_util/unity_roots.h" #include "mr_util/unity_roots.h"
#include "mr_util/useful_macros.h" #include "mr_util/useful_macros.h"
#include "mr_util/simd.h" #include "mr_util/simd.h"
#include "mr_util/mav.h"
#ifndef MRUTIL_NO_THREADING #ifndef MRUTIL_NO_THREADING
#include <mutex> #include <mutex>
#endif #endif
...@@ -76,9 +77,6 @@ template <typename T> T cos(T) = delete; ...@@ -76,9 +77,6 @@ template <typename T> T cos(T) = delete;
template <typename T> T sin(T) = delete; template <typename T> T sin(T) = delete;
template <typename T> T sqrt(T) = delete; template <typename T> T sqrt(T) = delete;
using shape_t = std::vector<size_t>;
using stride_t = std::vector<ptrdiff_t>;
constexpr bool FORWARD = true, constexpr bool FORWARD = true,
BACKWARD = false; BACKWARD = false;
...@@ -2293,51 +2291,11 @@ template<typename T> std::shared_ptr<T> get_plan(size_t length) ...@@ -2293,51 +2291,11 @@ template<typename T> std::shared_ptr<T> get_plan(size_t length)
#endif #endif
} }
class arr_info
{
protected:
shape_t shp;
stride_t str;
public:
arr_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_) {}
size_t ndim() const { return shp.size(); }
size_t size() const { return util::prod(shp); }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const ptrdiff_t &stride(size_t i) const { return str[i]; }
};
template<typename T> class cndarr: public arr_info
{
protected:
const char *d;
public:
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
: arr_info(shape_, stride_),
d(reinterpret_cast<const char *>(data_)) {}
const T &operator[](ptrdiff_t ofs) const
{ return *reinterpret_cast<const T *>(d+ofs); }
};
template<typename T> class ndarr: public cndarr<T>
{
public:
ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
: cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)
{}
T &operator[](ptrdiff_t ofs)
{ return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d+ofs)); }
};
template<size_t N> class multi_iter template<size_t N> class multi_iter
{ {
private: private:
shape_t pos; shape_t pos;
const arr_info &iarr, &oarr; fmav_info iarr, oarr;
ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
size_t idim, rem; size_t idim, rem;
...@@ -2358,7 +2316,7 @@ template<size_t N> class multi_iter ...@@ -2358,7 +2316,7 @@ template<size_t N> class multi_iter
} }
public: public:
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_, multi_iter(const fmav_info &iarr_, const fmav_info &oarr_, size_t idim_,
size_t nshares, size_t myshare) size_t nshares, size_t myshare)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
...@@ -2412,12 +2370,12 @@ class simple_iter ...@@ -2412,12 +2370,12 @@ class simple_iter
{ {
private: private:
shape_t pos; shape_t pos;
const arr_info &arr; fmav_info arr;
ptrdiff_t p; ptrdiff_t p;
size_t rem; size_t rem;
public: public:
simple_iter(const arr_info &arr_) simple_iter(const fmav_info &arr_)
: pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {} : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
void advance() void advance()
{ {
...@@ -2440,7 +2398,7 @@ class rev_iter ...@@ -2440,7 +2398,7 @@ class rev_iter
{ {
private: private:
shape_t pos; shape_t pos;
const arr_info &arr; fmav_info arr;
std::vector<char> rev_axis; std::vector<char> rev_axis;
std::vector<char> rev_jump; std::vector<char> rev_jump;
size_t last_axis, last_size; size_t last_axis, last_size;
...@@ -2449,7 +2407,7 @@ class rev_iter ...@@ -2449,7 +2407,7 @@ class rev_iter
size_t rem; size_t rem;
public: public:
rev_iter(const arr_info &arr_, const shape_t &axes) rev_iter(const fmav_info &arr_, const shape_t &axes)
: pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
rev_jump(arr_.ndim(), 1), p(0), rp(0) rev_jump(arr_.ndim(), 1), p(0), rp(0)
{ {
...@@ -2517,15 +2475,15 @@ template<> struct VTYPE<long double> ...@@ -2517,15 +2475,15 @@ template<> struct VTYPE<long double>
}; };
#endif #endif
template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape, template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize) size_t axsize)
{ {
auto othersize = util::prod(shape)/axsize; auto othersize = util::prod(shape)/axsize;
auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1); auto tmpsize = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
return aligned_array<char>(tmpsize*elemsize); return aligned_array<T>(tmpsize);
} }
template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape, template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
const shape_t &axes, size_t elemsize) const shape_t &axes)
{ {
size_t fullsize=util::prod(shape); size_t fullsize=util::prod(shape);
size_t tmpsize=0; size_t tmpsize=0;
...@@ -2533,14 +2491,14 @@ template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape, ...@@ -2533,14 +2491,14 @@ template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
{ {
auto axsize = shape[axes[i]]; auto axsize = shape[axes[i]];
auto othersize = fullsize/axsize; auto othersize = fullsize/axsize;
auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1); auto sz = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
if (sz>tmpsize) tmpsize=sz; if (sz>tmpsize) tmpsize=sz;
} }
return aligned_array<char>(tmpsize*elemsize); return aligned_array<T>(tmpsize);
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst) const cfmav<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst)
{ {
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -2551,7 +2509,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -2551,7 +2509,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst) const cfmav<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst)
{ {
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -2559,7 +2517,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -2559,7 +2517,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<T> &src, T *MRUTIL_RESTRICT dst) const cfmav<T> &src, T *MRUTIL_RESTRICT dst)
{ {
if (dst == &src[it.iofs(0)]) return; // in-place if (dst == &src[it.iofs(0)]) return; // in-place
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
...@@ -2567,7 +2525,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -2567,7 +2525,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, ndarr<Cmplx<T>> &dst) const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, const fmav<Cmplx<T>> &dst)
{ {
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -2575,7 +2533,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, ...@@ -2575,7 +2533,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst) const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
{ {
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -2583,7 +2541,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, ...@@ -2583,7 +2541,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, ndarr<T> &dst) const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
{ {
if (src == &dst[it.oofs(0)]) return; // in-place if (src == &dst[it.oofs(0)]) return; // in-place
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
...@@ -2596,7 +2554,7 @@ template <typename T> struct add_vec<Cmplx<T>> ...@@ -2596,7 +2554,7 @@ template <typename T> struct add_vec<Cmplx<T>>
template <typename T> using add_vec_t = typename add_vec<T>::type; template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec> template<typename Tplan, typename T, typename T0, typename Exec>
MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out, MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true) const bool allow_inplace=true)
{ {
...@@ -2612,8 +2570,8 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out, ...@@ -2612,8 +2570,8 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val), util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
[&](Scheduler &sched) { [&](Scheduler &sched) {
constexpr auto vlen = VLEN<T0>::val; constexpr auto vlen = VLEN<T0>::val;
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T,T0>(in.shape(), len);
const auto &tin(iax==0? in : out); const auto &tin(iax==0? in : cfmav<T>(out));
multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num()); multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
if (vlen>1) if (vlen>1)
...@@ -2627,7 +2585,7 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out, ...@@ -2627,7 +2585,7 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
while (it.remaining()>0) while (it.remaining()>0)
{ {
it.advance(1); it.advance(1);
auto buf = allow_inplace && it.stride_out() == sizeof(T) ? auto buf = allow_inplace && it.stride_out() == 1 ?
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data()); &out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
exec(it, tin, out, buf, *plan, fct); exec(it, tin, out, buf, *plan, fct);
} }
...@@ -2641,8 +2599,8 @@ struct ExecC2C ...@@ -2641,8 +2599,8 @@ struct ExecC2C
bool forward; bool forward;
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<Cmplx<T0>> &in, const multi_iter<vlen> &it, const cfmav<Cmplx<T0>> &in,
ndarr<Cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const const fmav<Cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
plan.exec(buf, fct, forward); plan.exec(buf, fct, forward);
...@@ -2651,7 +2609,7 @@ struct ExecC2C ...@@ -2651,7 +2609,7 @@ struct ExecC2C
}; };
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst) const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
{ {
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,0)] = src[0][j]; dst[it.oofs(j,0)] = src[0][j];
...@@ -2668,7 +2626,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, ...@@ -2668,7 +2626,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
} }
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, ndarr<T> &dst) const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
{ {
dst[it.oofs(0)] = src[0]; dst[it.oofs(0)] = src[0];
size_t i=1, i1=1, i2=it.length_out()-1; size_t i=1, i1=1, i2=it.length_out()-1;
...@@ -2684,7 +2642,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, ...@@ -2684,7 +2642,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
struct ExecHartley struct ExecHartley
{ {
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out,
T * buf, const pocketfft_r<T0> &plan, T0 fct) const T * buf, const pocketfft_r<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
...@@ -2700,8 +2658,8 @@ struct ExecDcst ...@@ -2700,8 +2658,8 @@ struct ExecDcst
bool cosine; bool cosine;
template <typename T0, typename T, typename Tplan, size_t vlen> template <typename T0, typename T, typename Tplan, size_t vlen>
void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in, void operator () (const multi_iter<vlen> &it, const cfmav<T0> &in,
ndarr<T0> &out, T * buf, const Tplan &plan, T0 fct) const const fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
plan.exec(buf, fct, ortho, type, cosine); plan.exec(buf, fct, ortho, type, cosine);
...@@ -2710,7 +2668,7 @@ struct ExecDcst ...@@ -2710,7 +2668,7 @@ struct ExecDcst
}; };
template<typename T> MRUTIL_NOINLINE void general_r2c( template<typename T> MRUTIL_NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<Cmplx<T>> &out, size_t axis, bool forward, T fct, const cfmav<T> &in, const fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads)
{ {
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
...@@ -2719,7 +2677,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -2719,7 +2677,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&](Scheduler &sched) { [&](Scheduler &sched) {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T)); auto storage = alloc_tmp<T,T>(in.shape(), len);
multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num()); multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
if (vlen>1) if (vlen>1)
...@@ -2765,7 +2723,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -2765,7 +2723,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
}); // end of parallel region }); // end of parallel region
} }
template<typename T> MRUTIL_NOINLINE void general_c2r( template<typename T> MRUTIL_NOINLINE void general_c2r(
const cndarr<Cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct, const cfmav<Cmplx<T>> &in, const fmav<T> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads)
{ {
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
...@@ -2774,7 +2732,7 @@ template<typename T> MRUTIL_NOINLINE void general_c2r( ...@@ -2774,7 +2732,7 @@ template<typename T> MRUTIL_NOINLINE void general_c2r(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val), util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&](Scheduler &sched) { [&](Scheduler &sched) {
constexpr auto vlen = VLEN<T>::val; constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T)); auto storage = alloc_tmp<T,T>(out.shape(), len);
multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num()); multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
if (vlen>1) if (vlen>1)
...@@ -2841,7 +2799,7 @@ struct ExecR2R ...@@ -2841,7 +2799,7 @@ struct ExecR2R
bool r2c, forward; bool r2c, forward;
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T * buf, const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out, T * buf,
const pocketfft_r<T0> &plan, T0 fct) const const pocketfft_r<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
...@@ -2863,8 +2821,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in, ...@@ -2863,8 +2821,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
{ {
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<Cmplx<T>> ain(data_in, shape, stride_in); const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape, stride_in);
ndarr<Cmplx<T>> aout(data_out, shape, stride_out); const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape, stride_out);
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
} }
...@@ -2875,8 +2833,8 @@ template<typename T> void dct(const shape_t &shape, ...@@ -2875,8 +2833,8 @@ template<typename T> void dct(const shape_t &shape,
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); const cfmav<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); const fmav<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, true}; const ExecDcst exec{ortho, type, true};
if (type==1) if (type==1)
general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec); general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
...@@ -2893,8 +2851,8 @@ template<typename T> void dst(const shape_t &shape, ...@@ -2893,8 +2851,8 @@ template<typename T> void dst(const shape_t &shape,
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); const cfmav<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); const fmav<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, false}; const ExecDcst exec{ortho, type, false};
if (type==1) if (type==1)
general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec); general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
...@@ -2911,10 +2869,10 @@ template<typename T> void r2c(const shape_t &shape_in, ...@@ -2911,10 +2869,10 @@ template<typename T> void r2c(const shape_t &shape_in,
{ {
if (util::prod(shape_in)==0) return; if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis); util::sanity_check(shape_in, stride_in, stride_out, false, axis);
cndarr<T> ain(data_in, shape_in, stride_in); const cfmav<T> ain(data_in, shape_in, stride_in);
shape_t shape_out(shape_in); shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1; shape_out[axis] = shape_in[axis]/2 + 1;
ndarr<Cmplx<T>> aout(data_out, shape_out, stride_out); const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape_out, stride_out);
general_r2c(ain, aout, axis, forward, fct, nthreads); general_r2c(ain, aout, axis, forward, fct, nthreads);
} }
...@@ -2945,8 +2903,8 @@ template<typename T> void c2r(const shape_t &shape_out, ...@@ -2945,8 +2903,8 @@ template<typename T> void c2r(const shape_t &shape_out,
util::sanity_check(shape_out, stride_in, stride_out, false, axis); util::sanity_check(shape_out, stride_in, stride_out, false, axis);
shape_t shape_in(shape_out); shape_t shape_in(shape_out);
shape_in[axis] = shape_out[axis]/2 + 1; shape_in[axis] = shape_out[axis]/2 + 1;
cndarr<Cmplx<T>> ain(data_in, shape_in, stride_in); const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape_in, stride_in);
ndarr<T> aout(data_out, shape_out, stride_out); const fmav<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, forward, fct, nthreads); general_c2r(ain, aout, axis, forward, fct, nthreads);
} }
...@@ -2964,7 +2922,7 @@ template<typename T> void c2r(const shape_t &shape_out, ...@@ -2964,7 +2922,7 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;