Commit 0911de6d authored by Martin Reinecke's avatar Martin Reinecke

better axes handling

parent a872975f
......@@ -2387,25 +2387,26 @@ auto c128 = py::dtype("complex128");
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
void check_args(const py::array &in, vector<size_t> &axes)
vector<size_t> makeaxes(const py::array &in, py::object axes)
{
if (axes.size()==0) throw runtime_error("no FFT axis specified");
if ((axes.size()==1) && axes[0]==10000) // indicator for "axes not given"
if (axes.is(py::none()))
{
axes.resize(in.ndim());
for (int i=0; i<in.ndim(); ++i)
axes[i] = i;
vector<size_t> res(in.ndim());
for (size_t i=0; i<res.size(); ++i)
res[i]=i;
return res;
}
if (axes.size()>size_t(in.ndim()))
auto tmp=axes.cast<vector<size_t>>();
if (tmp.size()>size_t(in.ndim()))
throw runtime_error("bad axes argument");
for (size_t i=0; i<axes.size(); ++i)
if (axes[i]>=size_t(in.ndim()))
for (size_t i=0; i<tmp.size(); ++i)
if (tmp[i]>=size_t(in.ndim()))
throw runtime_error("invalid axis number");
return tmp;
}
py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd)
{
check_args(in, axes);
py::array res;
vector<size_t> dims(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
......@@ -2437,14 +2438,14 @@ py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd
(cmplx<float> *)res.mutable_data(), fct);
return res;
}
py::array fftn(const py::array &a, const vector<size_t> &axes, double fct)
{ return execute(a, axes, fct, true); }
py::array ifftn(const py::array &a, const vector<size_t> &axes, double fct)
{ return execute(a, axes, fct, false); }
py::array fftn(const py::array &a, py::object axes, double fct)
{ return execute(a, makeaxes(a, axes), fct, true); }
py::array ifftn(const py::array &a, py::object axes, double fct)
{ return execute(a, makeaxes(a, axes), fct, false); }
py::array rfftn(const py::array &in, vector<size_t> axes, double fct)
py::array rfftn(const py::array &in, py::object axes_, double fct)
{
check_args(in, axes);
auto axes = makeaxes(in, axes_);
py::array res;
vector<size_t> dims_in(in.ndim()), dims_out(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
......@@ -2494,17 +2495,17 @@ py::array rfftn(const py::array &in, vector<size_t> axes, double fct)
}
return res;
}
py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize,
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct)
{
check_args(in, axes);
auto axes = makeaxes(in, axes_);
py::array inter;
if (axes.size()>1) // need to do complex transforms
{
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = ifftn(in, axes2, 1.);
inter = execute(in, axes2, 1., false);
}
else
inter = in;
......@@ -2548,9 +2549,9 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize,
(float *)res.mutable_data(), fct);
return res;
}
py::array hartley(const py::array &in, vector<size_t> axes, double fct)
py::array hartley(const py::array &in, py::object axes_, double fct)
{
check_args(in, axes);
auto axes = makeaxes(in, axes_);
py::array res;
vector<size_t> dims(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
......@@ -2583,6 +2584,19 @@ py::array hartley(const py::array &in, vector<size_t> axes, double fct)
return res;
}
const char *pypocketfft_DS = R"DELIM(
Fast Fourier and Hartley transforms.
This module supports
- single and double precision
- complex and real-valued transforms
- multi-dimensional transforms
For two- and higher-dimensional transforms the code will use SSE2 and AVX
vector instructions for faster execution if these are supported by the CPU and
were enabled during compilation.
)DELIM";
const char *fftn_DS = R"DELIM(
Performs a forward complex FFT.
......@@ -2693,9 +2707,10 @@ PYBIND11_MODULE(pypocketfft, m)
{
using namespace pybind11::literals;
m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "lastsize"_a=0, "fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.doc() = pypocketfft_DS;
m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0, "fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment