Commit 4ef6e3bf authored by Peter Bell

Refactor n-dimensional general_x functions

parent 343ff57e
......@@ -2895,24 +2895,29 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
dst[it.oofs(i)] = src[i];
}
template<typename T> POCKETFFT_NOINLINE void general_c(
const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
const shape_t &axes, bool forward, T fct, size_t POCKETFFT_NTHREADS)
// add_vec<E>::type names the element type used on the vectorised code path:
//   - for a plain element type E it is vtype_t<E>;
//   - for cmplx<R> it is cmplx<vtype_t<R>>, i.e. the wrapper is applied to
//     the component type and the cmplx<> shell is kept on the outside.
template <typename Elem> struct add_vec
  {
  using type = vtype_t<Elem>;
  };

template <typename Real> struct add_vec<cmplx<Real>>
  {
  using type = cmplx<vtype_t<Real>>;
  };

// Shorthand for typename add_vec<Elem>::type.
template <typename Elem> using add_vec_t = typename add_vec<Elem>::type;
// Shared N-dimensional driver: applies a 1-D transform along every axis
// listed in `axes`, one axis after the other.
//
// Template parameters
//   Tplan  plan type; obtained through get_plan<Tplan>(len) and queried with
//          length(). A new plan is only fetched when the axis length changes.
//   T      element type of the arrays.
//   T0     type of the scale factor; also selects the vector length
//          (VLEN<T0>::val) and the scratch allocation.
//   Exec   callable invoked as exec(it, src, out, buf, plan, fct). It is
//          responsible for loading the current line(s) into `buf`, running
//          the plan and storing the result into `out`.
//
// Parameters
//   in     input array (read for the first axis only).
//   out    output array; also the data source for every axis after the
//          first, so that the per-axis transforms are chained.
//   axes   axes to transform, processed in the given order.
//   fct    scale factor; handed to `exec` for the first axis and reset to 1
//          for all following axes.
//   nthreads  forwarded to util::thread_count() when POCKETFFT_OPENMP is
//          defined (the parameter is spelled through POCKETFFT_NTHREADS).
//   exec   see Exec above.
template<typename Tplan, typename T, typename T0, typename Exec>
POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
  const shape_t &axes, T0 fct, size_t POCKETFFT_NTHREADS, const Exec & exec)
  {
  shared_ptr<Tplan> plan;

  for (size_t iax=0; iax<axes.size(); ++iax)
    {
    constexpr auto vlen = VLEN<T0>::val;
    size_t len=in.shape(axes[iax]);
    if ((!plan) || (len!=plan->length()))
      plan = get_plan<Tplan>(len);

#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
    auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
    // First axis reads the caller's input; later axes must continue from the
    // partially transformed data that already lives in `out`.
    const auto &tin(iax==0? in : out);
    multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
    // Vectorised path: handle `vlen` lines per call while enough remain.
    if (vlen>1)
      while (it.remaining()>=vlen)
        {
        it.advance(vlen);
        auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
        // Pass `tin`, not `in`: the iterator was built on `tin`, and reading
        // from `in` here would discard the result of all previous axes.
        exec(it, tin, out, tdatav, *plan, fct);
        }
#endif
    // Scalar path for the remaining lines.
    while (it.remaining()>0)
      {
      it.advance(1);
      // When the output line has element stride sizeof(T) it is used
      // directly as work buffer; otherwise the scratch storage is used.
      auto buf = it.stride_out() == sizeof(T) ?
        &out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
      exec(it, tin, out, buf, *plan, fct);
      }
} // end of parallel region
    fct = T0(1); // factor has been applied, use 1 for remaining axes
    }
  }
struct ExecC2C
{
bool forward;
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<cmplx<T0>> &in,
ndarr<cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const
{
copy_input(it, in, buf);
plan.exec(buf, fct, forward);
copy_output(it, buf, out);
}
};
template<typename T> void general_c(
const cndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
const shape_t &axes, bool forward, T fct, size_t nthreads)
{
general_nd<pocketfft_c<T>>(in, out, axes, fct, nthreads, ExecC2C{forward});
}
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)
{
......@@ -2972,94 +2993,46 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
dst[it.oofs(i1)] = src[i];
}
struct ExecHartley
{
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out,
T * buf, const pocketfft_r<T0> &plan, T0 fct) const
{
copy_input(it, in, buf);
plan.exec(buf, fct, true);
copy_hartley(it, buf, out);
}
};
// general_hartley: Hartley transform of `in` into `out` over `axes`.
//
// NOTE(review): this span appears to come from a rendered diff whose +/-
// markers were lost. Three things are interleaved below:
//   (1) the superseded hand-written per-axis loop of general_hartley,
//   (2) its replacement, a single call to general_nd with ExecHartley{},
//   (3) the newly introduced ExecDcst functor.
// The comments mark which lines seem to belong to which; confirm against the
// repository before relying on them.
template<typename T> POCKETFFT_NOINLINE void general_hartley(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct,
// The next two lines both close the parameter list (macro-spelled name vs.
// plain `nthreads`); presumably only the second one is current.
size_t POCKETFFT_NTHREADS)
size_t nthreads)
{
// Superseded: local plan cache used by the hand-written loop further down.
shared_ptr<pocketfft_r<T>> plan;
// Replacement body: delegate the per-axis iteration to general_nd.
// NOTE(review): the superseded loop below always transforms in the scratch
// buffer (see `tdata`), whereas general_nd may pass a pointer into `out` as
// work buffer; check that the Hartley output copy is safe in that case.
general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads, ExecHartley{});
}
// --- superseded per-axis loop of general_hartley (continues below) ---
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
// Fetch a new plan only when the axis length differs from the cached one.
if ((!plan) || (len!=plan->length()))
plan = get_plan<pocketfft_r<T>>(len);
// --- new functor: carries the `ortho`, `type` and `cosine` arguments that
// its call operator (further down) forwards to plan.exec() ---
struct ExecDcst
{
bool ortho;
int type;
bool cosine;
// --- superseded general_hartley loop, continued ---
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
// First axis reads `in`; later axes read what is already in `out`.
const auto &tin(iax==0 ? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
// Vectorised path: `vlen` lines per iteration.
if (vlen>1)
while (it.remaining()>=vlen)
{
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
copy_input(it, tin, tdatav);
plan->exec(tdatav, fct, true);
copy_hartley(it, tdatav, out);
}
#endif
// Scalar path for the remaining lines.
while (it.remaining()>0)
{
it.advance(1);
// Always the scratch buffer here, never a pointer into `out`.
auto tdata = reinterpret_cast<T *>(storage.data());
copy_input(it, tin, tdata);
plan->exec(tdata, fct, true);
copy_hartley(it, tdata, out);
}
} // end of parallel region
fct = T(1); // factor has been applied, use 1 for remaining axes
// --- ExecDcst call operator: load the line(s) into `buf`, run the plan with
// the stored options, store the result into `out`. The plan type is a
// template parameter, so any plan offering this exec() signature fits. ---
template <typename T0, typename T, typename Tplan, size_t vlen>
void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in,
ndarr<T0> &out, T * buf, const Tplan &plan, T0 fct) const
{
copy_input(it, in, buf);
plan.exec(buf, fct, ortho, type, cosine);
copy_output(it, buf, out);
}
// The next two closers presumably end the superseded loop/function and the
// ExecDcst struct respectively.
}
};
// general_dcst: applies the transform plan `Trafo` along every axis in
// `axes`, passing `ortho`, `type` and `cosine` through to the plan.
//
// NOTE(review): this span appears to come from a rendered diff whose +/-
// markers were lost. The superseded hand-written per-axis loop and its
// replacement (the single general_nd call just before the final brace) are
// both shown; confirm against the repository.
template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
// The next two lines are two variants of the same parameter line (macro-
// spelled thread-count name vs. plain `nthreads`); presumably only the
// second one is current.
T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS)
T fct, bool ortho, int type, bool cosine, size_t nthreads)
{
// --- superseded per-axis loop starts here ---
shared_ptr<Trafo> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
// Fetch a new plan only when the axis length differs from the cached one.
if ((!plan) || (len!=plan->length()))
plan = get_plan<Trafo>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
// First axis reads `in`; later axes read what is already in `out`.
const auto &tin(iax==0 ? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
// Vectorised path: `vlen` lines per iteration, always via scratch storage.
if (vlen>1)
while (it.remaining()>=vlen)
{
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
copy_input(it, tin, tdatav);
plan->exec(tdatav, fct, ortho, type, cosine);
copy_output(it, tdatav, out);
}
#endif
// Scalar path for the remaining lines.
while (it.remaining()>0)
{
it.advance(1);
// Work directly in the output line when its element stride is sizeof(T),
// otherwise in the scratch storage.
auto buf = it.stride_out() == sizeof(T) ? &out[it.oofs(0)]
: reinterpret_cast<T *>(storage.data());
copy_input(it, tin, buf);
plan->exec(buf, fct, ortho, type, cosine);
copy_output(it, buf, out);
}
} // end of parallel region
fct = T(1); // factor has been applied, use 1 for remaining axes
}
// --- replacement body: delegate the per-axis iteration to general_nd, with
// an ExecDcst executor holding the three options ---
general_nd<Trafo>(in, out, axes, fct, nthreads, ExecDcst{ortho, type, cosine});
}
template<typename T> POCKETFFT_NOINLINE void general_r2c(
......@@ -3179,62 +3152,32 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
} // end of parallel region
}
// NOTE(review): this span appears to come from a rendered diff whose +/-
// markers were lost. It interleaves the superseded hand-written general_r
// (signature, plan cache and per-axis loop) with the newly introduced
// ExecR2R functor (members and call operator). Confirm against the
// repository which lines are current.
//
// --- superseded general_r signature ---
template<typename T> POCKETFFT_NOINLINE void general_r(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
bool forward, T fct, size_t POCKETFFT_NTHREADS)
// --- new functor: per-line executor holding the `r2c` and `forward` flags
// that select the sign flips and the transform direction below ---
struct ExecR2R
{
// Superseded: plan cache of the hand-written loop.
shared_ptr<pocketfft_r<T>> plan;
bool r2c, forward;
// Superseded: start of the per-axis loop.
for (size_t iax=0; iax<axes.size(); ++iax)
// New: call operator signature (its body resumes after the superseded
// loop, starting at the copy_input(it, in, buf) line).
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T * buf,
const pocketfft_r<T0> &plan, T0 fct) const
{
// --- superseded per-axis loop body ---
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
// Fetch a new plan only when the axis length differs from the cached one.
if ((!plan) || (len!=plan->length()))
plan = get_plan<pocketfft_r<T>>(len);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
// First axis reads `in`; later axes read what is already in `out`.
const auto &tin(iax==0 ? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
// Vectorised path: `vlen` lines per iteration, via scratch storage.
if (vlen>1)
while (it.remaining()>=vlen)
{
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
copy_input(it, tin, tdatav);
// Negate elements 2, 4, 6, ... before the transform when !r2c && forward.
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
tdatav[i] = -tdatav[i];
plan->exec (tdatav, fct, forward);
// Negate elements 2, 4, 6, ... after the transform when r2c && !forward.
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
tdatav[i] = -tdatav[i];
copy_output(it, tdatav, out);
}
#endif
// Scalar path for the remaining lines.
while (it.remaining()>0)
{
it.advance(1);
// Work directly in the output line when its element stride is sizeof(T),
// otherwise in the scratch storage.
auto buf = it.stride_out() == sizeof(T) ?
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
copy_input(it, tin, buf);
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
buf[i] = -buf[i];
plan->exec(buf, fct, forward);
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
buf[i] = -buf[i];
copy_output(it, buf, out);
}
} // end of parallel region
fct = T(1); // factor has been applied, use 1 for remaining axes
// --- new ExecR2R call operator body ---
// Same steps as the superseded scalar path, except that the loop bound is
// it.length_out() instead of the axis length `len` taken from `in`.
copy_input(it, in, buf);
// Negate elements 2, 4, 6, ... before the transform when !r2c && forward
// (presumably the imaginary parts of the packed real-FFT layout - TODO
// confirm).
if ((!r2c) && forward)
for (size_t i=2; i<it.length_out(); i+=2)
buf[i] = -buf[i];
plan.exec(buf, fct, forward);
// Negate the same elements after the transform when r2c && !forward.
if (r2c && (!forward))
for (size_t i=2; i<it.length_out(); i+=2)
buf[i] = -buf[i];
copy_output(it, buf, out);
}
};
// Real-valued transform over the given axes: forwards to general_nd with a
// pocketfft_r<T> plan and an ExecR2R executor holding `r2c` and `forward`.
template<typename T> POCKETFFT_NOINLINE void general_r(
  const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
  bool forward, T fct, size_t nthreads)
  {
  const ExecR2R line_exec{r2c, forward};
  general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads, line_exec);
  }
#undef POCKETFFT_NTHREADS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment