Commit f1c7d8d8 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'mav_rewrite' into 'master'

Mav rewrite

See merge request mtr/cxxbase!1
parents 0d8a2890 43f38678
...@@ -498,9 +498,9 @@ template<typename I> template<typename I2> ...@@ -498,9 +498,9 @@ template<typename I> template<typename I2>
double dr=base[o].max_pixrad(); // safety distance double dr=base[o].max_pixrad(); // safety distance
for (size_t i=0; i<nv; ++i) for (size_t i=0; i<nv; ++i)
{ {
crlimit(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr); crlimit.v(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr);
crlimit(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1); crlimit.v(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1);
crlimit(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr); crlimit.v(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr);
} }
} }
...@@ -563,9 +563,9 @@ template<typename I> void T_Healpix_Base<I>::query_multidisc_general ...@@ -563,9 +563,9 @@ template<typename I> void T_Healpix_Base<I>::query_multidisc_general
double dr=base[o].max_pixrad(); // safety distance double dr=base[o].max_pixrad(); // safety distance
for (size_t i=0; i<nv; ++i) for (size_t i=0; i<nv; ++i)
{ {
crlimit(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr); crlimit.v(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr);
crlimit(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1); crlimit.v(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1);
crlimit(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr); crlimit.v(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr);
} }
} }
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <sstream> #include <sstream>
#include <exception> #include <exception>
#include "mr_util/useful_macros.h"
namespace mr { namespace mr {
namespace detail_error_handling { namespace detail_error_handling {
...@@ -35,19 +37,6 @@ namespace detail_error_handling { ...@@ -35,19 +37,6 @@ namespace detail_error_handling {
#define MRUTIL_ERROR_HANDLING_LOC_ ::mr::detail_error_handling::CodeLocation(__FILE__, __LINE__) #define MRUTIL_ERROR_HANDLING_LOC_ ::mr::detail_error_handling::CodeLocation(__FILE__, __LINE__)
#endif #endif
#define MR_fail(...) \
do { \
::std::ostringstream msg; \
::mr::detail_error_handling::streamDump__(msg, MRUTIL_ERROR_HANDLING_LOC_, "\n", ##__VA_ARGS__, "\n"); \
throw ::std::runtime_error(msg.str()); \
} while(0)
#define MR_assert(cond,...) \
do { \
if (cond); \
else { MR_fail("Assertion failure\n", ##__VA_ARGS__); } \
} while(0)
// to be replaced with std::source_location once generally available // to be replaced with std::source_location once generally available
class CodeLocation class CodeLocation
{ {
...@@ -73,21 +62,39 @@ inline ::std::ostream &operator<<(::std::ostream &os, const CodeLocation &loc) ...@@ -73,21 +62,39 @@ inline ::std::ostream &operator<<(::std::ostream &os, const CodeLocation &loc)
#if (__cplusplus>=201703L) // hyper-elegant C++2017 version #if (__cplusplus>=201703L) // hyper-elegant C++2017 version
template<typename ...Args> template<typename ...Args>
inline void streamDump__(::std::ostream &os, Args&&... args) void streamDump__(::std::ostream &os, Args&&... args)
{ (os << ... << args); } { (os << ... << args); }
#else #else
template<typename T> template<typename T>
inline void streamDump__(::std::ostream &os, const T& value) void streamDump__(::std::ostream &os, const T& value)
{ os << value; } { os << value; }
template<typename T, typename ... Args> template<typename T, typename ... Args>
inline void streamDump__(::std::ostream &os, const T& value, void streamDump__(::std::ostream &os, const T& value,
const Args& ... args) const Args& ... args)
{ {
os << value; os << value;
streamDump__(os, args...); streamDump__(os, args...);
} }
#endif #endif
template<typename ...Args>
[[noreturn]] void MRUTIL_NOINLINE fail__(Args&&... args)
{
::std::ostringstream msg; \
::mr::detail_error_handling::streamDump__(msg, args...); \
throw ::std::runtime_error(msg.str()); \
}
#define MR_fail(...) \
do { \
::mr::detail_error_handling::fail__(MRUTIL_ERROR_HANDLING_LOC_, "\n", ##__VA_ARGS__, "\n"); \
} while(0)
#define MR_assert(cond,...) \
do { \
if (cond); \
else { MR_fail("Assertion failure\n", ##__VA_ARGS__); } \
} while(0)
}} }}
......
...@@ -621,7 +621,7 @@ template<typename T, typename T0> aligned_array<T> alloc_tmp ...@@ -621,7 +621,7 @@ template<typename T, typename T0> aligned_array<T> alloc_tmp
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cfmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst) const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
{ {
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -632,7 +632,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -632,7 +632,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cfmav<T> &src, native_simd<T> *MRUTIL_RESTRICT dst) const fmav<T> &src, native_simd<T> *MRUTIL_RESTRICT dst)
{ {
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
...@@ -640,7 +640,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -640,7 +640,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cfmav<T> &src, T *MRUTIL_RESTRICT dst) const fmav<T> &src, T *MRUTIL_RESTRICT dst)
{ {
if (dst == &src[it.iofs(0)]) return; // in-place if (dst == &src[it.iofs(0)]) return; // in-place
for (size_t i=0; i<it.length_in(); ++i) for (size_t i=0; i<it.length_in(); ++i)
...@@ -648,27 +648,30 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -648,27 +648,30 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, const fmav<Cmplx<T>> &dst) const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst)
{ {
auto ptr=dst.vdata();
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]); ptr[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const native_simd<T> *MRUTIL_RESTRICT src, const fmav<T> &dst) const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr=dst.vdata();
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i)] = src[i][j]; ptr[it.oofs(j,i)] = src[i][j];
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, const fmav<T> &dst) const T *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr=dst.vdata();
if (src == &dst[it.oofs(0)]) return; // in-place if (src == &dst[it.oofs(0)]) return; // in-place
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
dst[it.oofs(i)] = src[i]; ptr[it.oofs(i)] = src[i];
} }
template <typename T> struct add_vec { using type = native_simd<T>; }; template <typename T> struct add_vec { using type = native_simd<T>; };
...@@ -677,7 +680,7 @@ template <typename T> struct add_vec<Cmplx<T>> ...@@ -677,7 +680,7 @@ template <typename T> struct add_vec<Cmplx<T>>
template <typename T> using add_vec_t = typename add_vec<T>::type; template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec> template<typename Tplan, typename T, typename T0, typename Exec>
MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out, MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true) const bool allow_inplace=true)
{ {
...@@ -709,7 +712,7 @@ MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out, ...@@ -709,7 +712,7 @@ MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
{ {
it.advance(1); it.advance(1);
auto buf = allow_inplace && it.stride_out() == 1 ? auto buf = allow_inplace && it.stride_out() == 1 ?
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data()); &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
exec(it, tin, out, buf, *plan, fct); exec(it, tin, out, buf, *plan, fct);
} }
}); // end of parallel region }); // end of parallel region
...@@ -721,9 +724,9 @@ struct ExecC2C ...@@ -721,9 +724,9 @@ struct ExecC2C
{ {
bool forward; bool forward;
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator() (
const multi_iter<vlen> &it, const cfmav<Cmplx<T0>> &in, const multi_iter<vlen> &it, const fmav<Cmplx<T0>> &in,
const fmav<Cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const fmav<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
plan.exec(buf, fct, forward); plan.exec(buf, fct, forward);
...@@ -732,40 +735,42 @@ struct ExecC2C ...@@ -732,40 +735,42 @@ struct ExecC2C
}; };
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const native_simd<T> *MRUTIL_RESTRICT src, const fmav<T> &dst) const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr = dst.vdata();
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,0)] = src[0][j]; ptr[it.oofs(j,0)] = src[0][j];
size_t i=1, i1=1, i2=it.length_out()-1; size_t i=1, i1=1, i2=it.length_out()-1;
for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2) for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
{ {
dst[it.oofs(j,i1)] = src[i][j]+src[i+1][j]; ptr[it.oofs(j,i1)] = src[i][j]+src[i+1][j];
dst[it.oofs(j,i2)] = src[i][j]-src[i+1][j]; ptr[it.oofs(j,i2)] = src[i][j]-src[i+1][j];
} }
if (i<it.length_out()) if (i<it.length_out())
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i1)] = src[i][j]; ptr[it.oofs(j,i1)] = src[i][j];
} }
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, const fmav<T> &dst) const T *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
dst[it.oofs(0)] = src[0]; auto ptr = dst.vdata();
ptr[it.oofs(0)] = src[0];
size_t i=1, i1=1, i2=it.length_out()-1; size_t i=1, i1=1, i2=it.length_out()-1;
for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2) for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
{ {
dst[it.oofs(i1)] = src[i]+src[i+1]; ptr[it.oofs(i1)] = src[i]+src[i+1];
dst[it.oofs(i2)] = src[i]-src[i+1]; ptr[it.oofs(i2)] = src[i]-src[i+1];
} }
if (i<it.length_out()) if (i<it.length_out())
dst[it.oofs(i1)] = src[i]; ptr[it.oofs(i1)] = src[i];
} }
struct ExecHartley struct ExecHartley
{ {
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out, const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out,
T * buf, const pocketfft_r<T0> &plan, T0 fct) const T * buf, const pocketfft_r<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
...@@ -781,8 +786,8 @@ struct ExecDcst ...@@ -781,8 +786,8 @@ struct ExecDcst
bool cosine; bool cosine;
template <typename T0, typename T, typename Tplan, size_t vlen> template <typename T0, typename T, typename Tplan, size_t vlen>
void operator () (const multi_iter<vlen> &it, const cfmav<T0> &in, void operator () (const multi_iter<vlen> &it, const fmav<T0> &in,
const fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
plan.exec(buf, fct, ortho, type, cosine); plan.exec(buf, fct, ortho, type, cosine);
...@@ -791,7 +796,7 @@ struct ExecDcst ...@@ -791,7 +796,7 @@ struct ExecDcst
}; };
template<typename T> MRUTIL_NOINLINE void general_r2c( template<typename T> MRUTIL_NOINLINE void general_r2c(
const cfmav<T> &in, const fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct, const fmav<T> &in, fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads)
{ {
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
...@@ -810,20 +815,21 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -810,20 +815,21 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data()); auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data());
copy_input(it, in, tdatav); copy_input(it, in, tdatav);
plan->exec(tdatav, fct, true); plan->exec(tdatav, fct, true);
auto vout = out.vdata();
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]); vout[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1; size_t i=1, ii=1;
if (forward) if (forward)
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
else else
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
if (i<len) if (i<len)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j]);
} }
#endif #endif
while (it.remaining()>0) while (it.remaining()>0)
...@@ -832,21 +838,22 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -832,21 +838,22 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
copy_input(it, in, tdata); copy_input(it, in, tdata);
plan->exec(tdata, fct, true); plan->exec(tdata, fct, true);
out[it.oofs(0)].Set(tdata[0]); auto vout = out.vdata();
vout[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1; size_t i=1, ii=1;
if (forward) if (forward)
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]); vout[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
else else
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]); vout[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
if (i<len) if (i<len)
out[it.oofs(ii)].Set(tdata[i]); vout[it.oofs(ii)].Set(tdata[i]);
} }
}); // end of parallel region }); // end of parallel region
} }
template<typename T> MRUTIL_NOINLINE void general_c2r( template<typename T> MRUTIL_NOINLINE void general_c2r(
const cfmav<Cmplx<T>> &in, const fmav<T> &out, size_t axis, bool forward, T fct, const fmav<Cmplx<T>> &in, fmav<T> &out, size_t axis, bool forward, T fct,
size_t nthreads) size_t nthreads)
{ {
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis)); auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
...@@ -922,7 +929,7 @@ struct ExecR2R ...@@ -922,7 +929,7 @@ struct ExecR2R
bool r2c, forward; bool r2c, forward;
template <typename T0, typename T, size_t vlen> void operator () ( template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out, T * buf, const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out, T * buf,
const pocketfft_r<T0> &plan, T0 fct) const const pocketfft_r<T0> &plan, T0 fct) const
{ {
copy_input(it, in, buf); copy_input(it, in, buf);
...@@ -937,18 +944,18 @@ struct ExecR2R ...@@ -937,18 +944,18 @@ struct ExecR2R
} }
}; };
template<typename T> void c2c(const cfmav<std::complex<T>> &in, template<typename T> void c2c(const fmav<std::complex<T>> &in,
const fmav<std::complex<T>> &out, const shape_t &axes, bool forward, fmav<std::complex<T>> &out, const shape_t &axes, bool forward,
T fct, size_t nthreads=1) T fct, size_t nthreads=1)
{ {
util::sanity_check_onetype(in, out, in.data()==out.data(), axes); util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return; if (in.size()==0) return;
cfmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in); fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out); fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward}); general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward});
} }
template<typename T> void dct(const cfmav<T> &in, const fmav<T> &out, template<typename T> void dct(const fmav<T> &in, fmav<T> &out,
const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1) const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
...@@ -963,7 +970,7 @@ template<typename T> void dct(const cfmav<T> &in, const fmav<T> &out, ...@@ -963,7 +970,7 @@ template<typename T> void dct(const cfmav<T> &in, const fmav<T> &out,
general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec); general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
} }
template<typename T> void dst(const cfmav<T> &in, const fmav<T> &out, template<typename T> void dst(const fmav<T> &in, fmav<T> &out,
const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1) const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
...@@ -977,18 +984,18 @@ template<typename T> void dst(const cfmav<T> &in, const fmav<T> &out, ...@@ -977,18 +984,18 @@ template<typename T> void dst(const cfmav<T> &in, const fmav<T> &out,
general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec); general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
} }
template<typename T> void r2c(const cfmav<T> &in, template<typename T> void r2c(const fmav<T> &in,
const fmav<std::complex<T>> &out, size_t axis, bool forward, T fct, fmav<std::complex<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads=1) size_t nthreads=1)
{ {
util::sanity_check_cr(out, in, axis); util::sanity_check_cr(out, in, axis);
if (in.size()==0) return; if (in.size()==0) return;
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out); fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
general_r2c(in, out2, axis, forward, fct, nthreads); general_r2c(in, out2, axis, forward, fct, nthreads);
} }
template<typename T> void r2c(const cfmav<T> &in, template<typename T> void r2c(const fmav<T> &in,
const fmav<std::complex<T>> &out, const shape_t &axes, fmav<std::complex<T>> &out, const shape_t &axes,
bool forward, T fct, size_t nthreads=1) bool forward, T fct, size_t nthreads=1)
{ {
util::sanity_check_cr(out, in, axes); util::sanity_check_cr(out, in, axes);
...@@ -1000,17 +1007,17 @@ template<typename T> void r2c(const cfmav<T> &in, ...@@ -1000,17 +1007,17 @@ template<typename T> void r2c(const cfmav<T> &in,
c2c(out, out, newaxes, forward, T(1), nthreads); c2c(out, out, newaxes, forward, T(1), nthreads);
} }
template<typename T> void c2r(const cfmav<std::complex<T>> &in, template<typename T> void c2r(const fmav<std::complex<T>> &in,
const fmav<T> &out, size_t axis, bool forward, T fct, size_t nthreads=1) fmav<T> &out, size_t axis, bool forward, T fct, size_t nthreads=1)
{ {
util::sanity_check_cr(in, out, axis); util::sanity_check_cr(in, out, axis);
if (in.size()==0) return; if (in.size()==0) return;
cfmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in); fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
general_c2r(in2, out, axis, forward, fct, nthreads); general_c2r(in2, out, axis, forward, fct, nthreads);
} }
template<typename T> void c2r(const cfmav<std::complex<T>> &in, template<typename T> void c2r(const fmav<std::complex<T>> &in,
const fmav<T> &out, const shape_t &axes, bool forward, T fct, fmav<T> &out, const shape_t &axes, bool forward, T fct,
size_t nthreads=1) size_t nthreads=1)
{ {
if (axes.size()==1) if (axes.size()==1)
...@@ -1023,8 +1030,8 @@ template<typename T> void c2r(const cfmav<std::complex<T>> &in, ...@@ -1023,8 +1030,8 @@ template<typename T> void c2r(const cfmav<std::complex<T>> &in,
c2r(atmp, out, axes.back(), forward, fct, nthreads); c2r(atmp, out, axes.back(), forward, fct, nthreads);
} }