Commit 7e8c30c6 authored by Martin Reinecke
Browse files

add 'forward' parameter to c2r and r2c

parent 89d506fd
...@@ -2587,7 +2587,7 @@ template<typename T> NOINLINE void general_hartley( ...@@ -2587,7 +2587,7 @@ template<typename T> NOINLINE void general_hartley(
} }
template<typename T> NOINLINE void general_r2c( template<typename T> NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, T fct, const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(in.shape(axis)); pocketfft_r<T> plan(in.shape(axis));
...@@ -2613,9 +2613,14 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2613,9 +2613,14 @@ template<typename T> NOINLINE void general_r2c(
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]); out[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1; size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii) if (forward)
for (size_t j=0; j<vlen; ++j) for (; i<len-1; i+=2, ++ii)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]); for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
else
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
if (i<len) if (i<len)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j]); out[it.oofs(j,ii)].Set(tdatav[i][j]);
...@@ -2630,15 +2635,19 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2630,15 +2635,19 @@ template<typename T> NOINLINE void general_r2c(
plan.forward(tdata, fct); plan.forward(tdata, fct);
out[it.oofs(0)].Set(tdata[0]); out[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1; size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii) if (forward)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]); for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
else
for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
if (i<len) if (i<len)
out[it.oofs(ii)].Set(tdata[i]); out[it.oofs(ii)].Set(tdata[i]);
} }
} // end of parallel region } // end of parallel region
} }
template<typename T> NOINLINE void general_c2r( template<typename T> NOINLINE void general_c2r(
const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, T fct, const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(out.shape(axis)); pocketfft_r<T> plan(out.shape(axis));
...@@ -2661,12 +2670,20 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2661,12 +2670,20 @@ template<typename T> NOINLINE void general_c2r(
tdatav[0][j]=in[it.iofs(j,0)].r; tdatav[0][j]=in[it.iofs(j,0)].r;
{ {
size_t i=1, ii=1; size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii) if (forward)
for (size_t j=0; j<vlen; ++j) for (; i<len-1; i+=2, ++ii)
{ for (size_t j=0; j<vlen; ++j)
tdatav[i ][j] = in[it.iofs(j,ii)].r; {
tdatav[i+1][j] = in[it.iofs(j,ii)].i; tdatav[i ][j] = in[it.iofs(j,ii)].r;
} tdatav[i+1][j] = -in[it.iofs(j,ii)].i;
}
else
for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j)
{
tdatav[i ][j] = in[it.iofs(j,ii)].r;
tdatav[i+1][j] = in[it.iofs(j,ii)].i;
}
if (i<len) if (i<len)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = in[it.iofs(j,ii)].r; tdatav[i][j] = in[it.iofs(j,ii)].r;
...@@ -2684,11 +2701,18 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2684,11 +2701,18 @@ template<typename T> NOINLINE void general_c2r(
tdata[0]=in[it.iofs(0)].r; tdata[0]=in[it.iofs(0)].r;
{ {
size_t i=1, ii=1; size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii) if (forward)
{ for (; i<len-1; i+=2, ++ii)
tdata[i ] = in[it.iofs(ii)].r; {
tdata[i+1] = in[it.iofs(ii)].i; tdata[i ] = in[it.iofs(ii)].r;
} tdata[i+1] = -in[it.iofs(ii)].i;
}
else
for (; i<len-1; i+=2, ++ii)
{
tdata[i ] = in[it.iofs(ii)].r;
tdata[i+1] = in[it.iofs(ii)].i;
}
if (i<len) if (i<len)
tdata[i] = in[it.iofs(ii)].r; tdata[i] = in[it.iofs(ii)].r;
} }
...@@ -2771,7 +2795,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in, ...@@ -2771,7 +2795,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
template<typename T> void r2c(const shape_t &shape_in, template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, size_t axis, const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1) bool forward, const T *data_in, complex<T> *data_out, T fct,
size_t nthreads=1)
{ {
if (util::prod(shape_in)==0) return; if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis); util::sanity_check(shape_in, stride_in, stride_out, false, axis);
...@@ -2779,29 +2804,30 @@ template<typename T> void r2c(const shape_t &shape_in, ...@@ -2779,29 +2804,30 @@ template<typename T> void r2c(const shape_t &shape_in,
shape_t shape_out(shape_in); shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1; shape_out[axis] = shape_in[axis]/2 + 1;
ndarr<cmplx<T>> aout(data_out, shape_out, stride_out); ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
general_r2c(ain, aout, axis, fct, nthreads); general_r2c(ain, aout, axis, forward, fct, nthreads);
} }
template<typename T> void r2c(const shape_t &shape_in, template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1) bool forward, const T *data_in, complex<T> *data_out, T fct, size_t nthreads=1)
{ {
if (util::prod(shape_in)==0) return; if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axes); util::sanity_check(shape_in, stride_in, stride_out, false, axes);
r2c(shape_in, stride_in, stride_out, axes.back(), data_in, data_out, fct, r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,
nthreads); fct, nthreads);
if (axes.size()==1) return; if (axes.size()==1) return;
shape_t shape_out(shape_in); shape_t shape_out(shape_in);
shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;
auto newaxes = shape_t{axes.begin(), --axes.end()}; auto newaxes = shape_t{axes.begin(), --axes.end()};
c2c(shape_out, stride_out, stride_out, newaxes, true, data_out, data_out, c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,
T(1), nthreads); T(1), nthreads);
} }
template<typename T> void c2r(const shape_t &shape_out, template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis, const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const complex<T> *data_in, T *data_out, T fct, size_t nthreads=1) bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
{ {
if (util::prod(shape_out)==0) return; if (util::prod(shape_out)==0) return;
util::sanity_check(shape_out, stride_in, stride_out, false, axis); util::sanity_check(shape_out, stride_in, stride_out, false, axis);
...@@ -2809,17 +2835,18 @@ template<typename T> void c2r(const shape_t &shape_out, ...@@ -2809,17 +2835,18 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in[axis] = shape_out[axis]/2 + 1; shape_in[axis] = shape_out[axis]/2 + 1;
cndarr<cmplx<T>> ain(data_in, shape_in, stride_in); cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
ndarr<T> aout(data_out, shape_out, stride_out); ndarr<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, fct, nthreads); general_c2r(ain, aout, axis, forward, fct, nthreads);
} }
template<typename T> void c2r(const shape_t &shape_out, template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const complex<T> *data_in, T *data_out, T fct, size_t nthreads=1) bool forward, const complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
{ {
if (util::prod(shape_out)==0) return; if (util::prod(shape_out)==0) return;
if (axes.size()==1) if (axes.size()==1)
return c2r(shape_out, stride_in, stride_out, axes[0], data_in, data_out, return c2r(shape_out, stride_in, stride_out, axes[0], forward,
fct, nthreads); data_in, data_out, fct, nthreads);
util::sanity_check(shape_out, stride_in, stride_out, false, axes); util::sanity_check(shape_out, stride_in, stride_out, false, axes);
auto shape_in = shape_out; auto shape_in = shape_out;
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
...@@ -2830,10 +2857,10 @@ template<typename T> void c2r(const shape_t &shape_out, ...@@ -2830,10 +2857,10 @@ template<typename T> void c2r(const shape_t &shape_out,
stride_inter[size_t(i)] = stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); stride_inter[size_t(i)] = stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
arr<complex<T>> tmp(nval); arr<complex<T>> tmp(nval);
auto newaxes = shape_t({axes.begin(), --axes.end()}); auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape_in, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),
T(1), nthreads); T(1), nthreads);
c2r(shape_out, stride_inter, stride_out, axes.back(), tmp.data(), data_out, c2r(shape_out, stride_inter, stride_out, axes.back(), forward,
fct, nthreads); tmp.data(), data_out, fct, nthreads);
} }
template<typename T> void r2r_fftpack(const shape_t &shape, template<typename T> void r2r_fftpack(const shape_t &shape,
......
...@@ -128,7 +128,7 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in, ...@@ -128,7 +128,7 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes); T fct = norm_fct<T>(inorm, dims, axes);
r2c(dims, s_in, s_out, axes, d_in, d_out, fct, nthreads); r2c(dims, s_in, s_out, axes, fwd, d_in, d_out, fct, nthreads);
// now fill in second half // now fill in second half
using namespace pocketfft::detail; using namespace pocketfft::detail;
ndarr<complex<T>> ares(res.mutable_data(), dims, s_out); ndarr<complex<T>> ares(res.mutable_data(), dims, s_out);
...@@ -139,16 +139,6 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in, ...@@ -139,16 +139,6 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
ares[iter.rev_ofs()] = conj(v); ares[iter.rev_ofs()] = conj(v);
iter.advance(); iter.advance();
} }
if (!fwd)
{
simple_iter iter(ares);
while(iter.remaining()>0)
{
auto v = ares[iter.ofs()];
ares[iter.ofs()] = conj(v);
iter.advance();
}
}
} }
return res; return res;
} }
...@@ -187,7 +177,7 @@ template<typename T> py::array rfftn_internal(const py::array &in, ...@@ -187,7 +177,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_in, axes); T fct = norm_fct<T>(inorm, dims_in, axes);
r2c(dims_in, s_in, s_out, axes, d_in, d_out, fct, nthreads); r2c(dims_in, s_in, s_out, axes, true, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
...@@ -246,7 +236,7 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -246,7 +236,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_out, axes); T fct = norm_fct<T>(inorm, dims_out, axes);
c2r(dims_out, s_in, s_out, axes, d_in, d_out, fct, nthreads); c2r(dims_out, s_in, s_out, axes, false, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
...@@ -324,7 +314,7 @@ py::array hartley2(const py::array &in, py::object axes_, int inorm, ...@@ -324,7 +314,7 @@ py::array hartley2(const py::array &in, py::object axes_, int inorm,
const char *pypocketfft_DS = R"""(Fast Fourier and Hartley transforms. const char *pypocketfft_DS = R"""(Fast Fourier and Hartley transforms.
This module supports This module supports
- single and double precision - single, double, and long double precision
- complex and real-valued transforms - complex and real-valued transforms
- multi-dimensional transforms - multi-dimensional transforms
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment