Commit da7ba7dd authored by Martin Reinecke
Browse files

allow multiple axes for r2r_fftpack

parent 9d8af0bc
......@@ -2724,59 +2724,96 @@ template<typename T> NOINLINE void general_c2r(
}
// NOTE(review): this span is a rendered diff hunk, not compilable code. Lines
// of a single-`axis` general_r (local `plan` object, accessed as `plan.`) and
// of an `axes`-based general_r (`plan` held in a unique_ptr, accessed as
// `plan->`, extra `r2c` flag) are interleaved without +/- markers. Judging by
// the commit title, the `axes` lines are presumably the post-change side.
//
// What the `axes`-based lines do, as far as this hunk shows:
//  - loop over every entry of `axes` and run a pocketfft_r plan along it;
//  - read from `in` for the first axis and from `out` for all later axes;
//  - hand `fct` to the transform of the first axis only (it is overwritten
//    with T(1) at the end of each loop iteration);
//  - negate the elements at indices 2, 4, 6, ... of each 1D line before a
//    forward transform when `r2c` is false, and after a backward transform
//    when `r2c` is true (presumably the imaginary parts of an FFTPACK-style
//    halfcomplex line -- TODO confirm against the pocketfft_r storage layout).
template<typename T> NOINLINE void general_r(
// parameter list of the single-`axis` variant
const cndarr<T> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
// parameter list of the `axes`/`r2c` variant
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, bool r2c,
bool forward, T fct, size_t POCKETFFT_NTHREADS)
{
constexpr auto vlen = VLEN<T>::val;
// single-`axis` variant: one plan, built once for the length of that axis
size_t len=in.shape(axis);
pocketfft_r<T> plan(len);
// `axes` variant: the plan starts empty and is (re)built inside the loop
unique_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
// allocate a new plan only if there is none yet or the current plan was
// built for a different length; otherwise the previous plan is reused
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
#ifdef POCKETFFT_OPENMP
// two variants of the same pragma: thread count derived from `axis` vs. `axes[iax]`
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axes[iax]))
#endif
{
// scratch buffer and iterator are declared inside the braces that follow the
// OpenMP pragma (presumably one instance per thread -- TODO confirm)
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
// first axis reads from the caller's input; every later axis reads from `out`
const auto &tin(iax==0 ? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
#ifndef POCKETFFT_NO_VECTORS
// --- vectorised path, single-`axis` variant: vlen lines per pass ---
if (vlen>1)
while (it.remaining()>=vlen)
{
using vtype = typename VTYPE<T>::type;
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype *>(storage.data());
// gather vlen lines from `in` into the scratch buffer
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,i)];
forward ? plan.forward (tdatav, fct)
: plan.backward(tdatav, fct);
// scatter the transformed lines into `out`
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
}
// --- vectorised path, `axes` variant: same gather/transform/scatter, plus
// the conditional sign flips around the transform ---
if (vlen>1)
while (it.remaining()>=vlen)
{
using vtype = typename VTYPE<T>::type;
it.advance(vlen);
auto tdatav = reinterpret_cast<vtype *>(storage.data());
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
// forward and not r2c: negate entries 2,4,... before transforming
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = -tdatav[i][j];
forward ? plan->forward (tdatav, fct)
: plan->backward(tdatav, fct);
// backward and r2c: negate entries 2,4,... after transforming
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = -tdatav[i][j];
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
}
#endif
// --- scalar path: handles the lines the vectorised loop did not consume.
// From here on the two variants are interleaved at a finer grain; the
// `plan.` lines belong to the single-`axis` variant, the `plan->` lines to
// the `axes` variant. ---
while (it.remaining()>0)
{
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&in[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
forward ? plan.forward (&out[it.oofs(0)], fct)
: plan.backward(&out[it.oofs(0)], fct);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = in[it.iofs(i)];
forward ? plan.forward (&out[it.oofs(0)], fct)
: plan.backward(&out[it.oofs(0)], fct);
}
else
while (it.remaining()>0)
{
// general case (single-`axis` variant): copy to scratch, transform, copy back
for (size_t i=0; i<len; ++i)
tdata[i] = in[it.iofs(i)];
forward ? plan.forward (tdata, fct) : plan.backward(tdata, fct);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
// source and destination share their first element and the output stride
// equals sizeof(T): transform directly inside `out`
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
{
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
out[it.oofs(i)] = -out[it.oofs(i)];
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
out[it.oofs(i)] = -out[it.oofs(i)];
}
// output stride equals sizeof(T) but data comes from elsewhere: copy the
// line into `out` first, then transform it there
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
out[it.oofs(i)] = -out[it.oofs(i)];
forward ? plan->forward (&out[it.oofs(0)], fct)
: plan->backward(&out[it.oofs(0)], fct);
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
out[it.oofs(i)] = -out[it.oofs(i)];
}
// any other output stride: go through the scratch buffer
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
if ((!r2c) && forward)
for (size_t i=2; i<len; i+=2)
tdata[i] = -tdata[i];
forward ? plan->forward (tdata, fct) : plan->backward(tdata, fct);
if (r2c && (!forward))
for (size_t i=2; i<len; i+=2)
tdata[i] = -tdata[i];
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
}
}
} // end of parallel region
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
#undef POCKETFFT_NTHREADS
......@@ -2864,14 +2901,15 @@ template<typename T> void c2r(const shape_t &shape_out,
}
// NOTE(review): rendered diff hunk, not compilable as shown. The single-`axis`
// variant and the `axes`/`r2c` variant of r2r_fftpack are interleaved without
// +/- markers: two parameter lists, two sanity_check calls and two general_r
// calls appear back to back.
//
// Visible behaviour of either variant: return immediately when the product of
// `shape` is zero, run util::sanity_check on shape/strides/axis (telling it
// whether input and output pointers are identical), wrap `data_in` and
// `data_out` with the shape and the respective strides in cndarr/ndarr
// objects, and forward everything to general_r.
template<typename T> void r2r_fftpack(const shape_t &shape,
// parameter list of the single-`axis` variant
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const T *data_in, T *data_out, T fct, size_t nthreads=1)
// parameter list of the `axes`/`r2c` variant
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool r2c, bool forward, const T *data_in, T *data_out, T fct,
size_t nthreads=1)
{
// nothing to transform if any dimension has extent 0
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
general_r(ain, aout, axis, forward, fct, nthreads);
general_r(ain, aout, axes, r2c, forward, fct, nthreads);
}
template<typename T> void r2r_hartley(const shape_t &shape,
......
......@@ -205,7 +205,7 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims[axis]);
r2r_fftpack(dims, s_in, s_out, axis, fwd, d_in, d_out, fct, nthreads);
r2r_fftpack(dims, s_in, s_out, {axis}, fwd, fwd, d_in, d_out, fct, nthreads);
}
return res;
}
......
......@@ -54,6 +54,11 @@ def test():
print("fmaxerrf:", fmaxerrf, shape, axes)
b=pypocketfft.hartley(pypocketfft.hartley(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b)
if err > hmaxerr:
hmaxerr = err
print("hmaxerr:", hmaxerr, shape, axes)
b=pypocketfft.hartley2(pypocketfft.hartley2(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b)
if err > hmaxerr:
hmaxerr = err
print("hmaxerr:", hmaxerr, shape, axes)
......@@ -62,6 +67,11 @@ def test():
if err > hmaxerrf:
hmaxerrf = err
print("hmaxerrf:", hmaxerrf, shape, axes)
b=pypocketfft.hartley2(pypocketfft.hartley2(a.real.astype(np.float32),axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real.astype(np.float32),b)
if err > hmaxerrf:
hmaxerrf = err
print("hmaxerrf:", hmaxerrf, shape, axes)
while True:
test()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment