Commit 7e8c30c6 authored by Martin Reinecke
Browse files

add 'forward' parameter to c2r and r2c

parent 89d506fd
......@@ -2587,7 +2587,7 @@ template<typename T> NOINLINE void general_hartley(
}
template<typename T> NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, T fct,
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(in.shape(axis));
......@@ -2613,9 +2613,14 @@ template<typename T> NOINLINE void general_r2c(
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
if (forward)
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
else
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
if (i<len)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j]);
......@@ -2630,15 +2635,19 @@ template<typename T> NOINLINE void general_r2c(
plan.forward(tdata, fct);
out[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
if (forward)
for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
else
for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
if (i<len)
out[it.oofs(ii)].Set(tdata[i]);
}
} // end of parallel region
}
template<typename T> NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, T fct,
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(out.shape(axis));
......@@ -2661,12 +2670,20 @@ template<typename T> NOINLINE void general_c2r(
tdatav[0][j]=in[it.iofs(j,0)].r;
{
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
{
tdatav[i ][j] = in[it.iofs(j,ii)].r;
tdatav[i+1][j] = in[it.iofs(j,ii)].i;
}
if (forward)
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
{
tdatav[i ][j] = in[it.iofs(j,ii)].r;
tdatav[i+1][j] = -in[it.iofs(j,ii)].i;
}
else
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
{
tdatav[i ][j] = in[it.iofs(j,ii)].r;
tdatav[i+1][j] = in[it.iofs(j,ii)].i;
}
if (i<len)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,ii)].r;
......@@ -2684,11 +2701,18 @@ template<typename T> NOINLINE void general_c2r(
tdata[0]=in[it.iofs(0)].r;
{
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
{
tdata[i ] = in[it.iofs(ii)].r;
tdata[i+1] = in[it.iofs(ii)].i;
}
if (forward)
for (; i<len-1; i+=2, ++ii)
{
tdata[i ] = in[it.iofs(ii)].r;
tdata[i+1] = -in[it.iofs(ii)].i;
}
else
for (; i<len-1; i+=2, ++ii)
{
tdata[i ] = in[it.iofs(ii)].r;
tdata[i+1] = in[it.iofs(ii)].i;
}
if (i<len)
tdata[i] = in[it.iofs(ii)].r;
}
......@@ -2771,7 +2795,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1)
bool forward, const T *data_in, complex<T> *data_out, T fct,
size_t nthreads=1)
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis);
......@@ -2779,29 +2804,30 @@ template<typename T> void r2c(const shape_t &shape_in,
shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1;
ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
general_r2c(ain, aout, axis, fct, nthreads);
general_r2c(ain, aout, axis, forward, fct, nthreads);
}
// Multi-axis real-to-complex transform.
// NOTE(review): this span is a diff rendering without +/- markers. Where a
// statement appears twice in a row, the first copy appears to be the
// pre-commit text and the second the post-commit text that threads the new
// `forward` flag through -- confirm against the actual header.
//
// Steps visible here: a real-to-complex pass along axes.back(), followed
// (when more than one axis is given) by an in-place complex-to-complex pass
// over the remaining axes of data_out.
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1)
bool forward, const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1)
{
// Nothing to do when the input has zero elements.
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axes);
// Real-to-complex pass along the last listed axis; the scaling factor fct is
// passed to this call only (the c2c call below is given T(1)).
r2c(shape_in, stride_in, stride_out, axes.back(), data_in, data_out, fct,
nthreads);
r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,
fct, nthreads);
// With a single axis the r2c pass above is the whole transform.
if (axes.size()==1) return;
// Along the transformed axis the output extent is n/2 + 1.
shape_t shape_out(shape_in);
shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;
// All axes except the last one are handled by c2c, reading and writing
// data_out. The first c2c line hard-codes `true` as the direction argument;
// the second passes the caller's `forward` flag instead.
auto newaxes = shape_t{axes.begin(), --axes.end()};
c2c(shape_out, stride_out, stride_out, newaxes, true, data_out, data_out,
c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,
T(1), nthreads);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const complex<T> *data_in, T *data_out, T fct, size_t nthreads=1)
bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
{
if (util::prod(shape_out)==0) return;
util::sanity_check(shape_out, stride_in, stride_out, false, axis);
......@@ -2809,17 +2835,18 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in[axis] = shape_out[axis]/2 + 1;
cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
ndarr<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, fct, nthreads);
general_c2r(ain, aout, axis, forward, fct, nthreads);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const complex<T> *data_in, T *data_out, T fct, size_t nthreads=1)
bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
{
if (util::prod(shape_out)==0) return;
if (axes.size()==1)
return c2r(shape_out, stride_in, stride_out, axes[0], data_in, data_out,
fct, nthreads);
return c2r(shape_out, stride_in, stride_out, axes[0], forward,
data_in, data_out, fct, nthreads);
util::sanity_check(shape_out, stride_in, stride_out, false, axes);
auto shape_in = shape_out;
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
......@@ -2830,10 +2857,10 @@ template<typename T> void c2r(const shape_t &shape_out,
stride_inter[size_t(i)] = stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
arr<complex<T>> tmp(nval);
auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape_in, stride_in, stride_inter, newaxes, false, data_in, tmp.data(),
c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),
T(1), nthreads);
c2r(shape_out, stride_inter, stride_out, axes.back(), tmp.data(), data_out,
fct, nthreads);
c2r(shape_out, stride_inter, stride_out, axes.back(), forward,
tmp.data(), data_out, fct, nthreads);
}
template<typename T> void r2r_fftpack(const shape_t &shape,
......
......@@ -128,7 +128,7 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes);
r2c(dims, s_in, s_out, axes, d_in, d_out, fct, nthreads);
r2c(dims, s_in, s_out, axes, fwd, d_in, d_out, fct, nthreads);
// now fill in second half
using namespace pocketfft::detail;
ndarr<complex<T>> ares(res.mutable_data(), dims, s_out);
......@@ -139,16 +139,6 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
ares[iter.rev_ofs()] = conj(v);
iter.advance();
}
if (!fwd)
{
simple_iter iter(ares);
while(iter.remaining()>0)
{
auto v = ares[iter.ofs()];
ares[iter.ofs()] = conj(v);
iter.advance();
}
}
}
return res;
}
......@@ -187,7 +177,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_in, axes);
r2c(dims_in, s_in, s_out, axes, d_in, d_out, fct, nthreads);
r2c(dims_in, s_in, s_out, axes, true, d_in, d_out, fct, nthreads);
}
return res;
}
......@@ -246,7 +236,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
{
py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_out, axes);
c2r(dims_out, s_in, s_out, axes, d_in, d_out, fct, nthreads);
c2r(dims_out, s_in, s_out, axes, false, d_in, d_out, fct, nthreads);
}
return res;
}
......@@ -324,7 +314,7 @@ py::array hartley2(const py::array &in, py::object axes_, int inorm,
const char *pypocketfft_DS = R"""(Fast Fourier and Hartley transforms.
This module supports
- single and double precision
- single, double, and long double precision
- complex and real-valued transforms
- multi-dimensional transforms
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment