Commit 9dae4cdd authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent 92cfcfdd
......@@ -6,6 +6,7 @@
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
......@@ -2064,6 +2065,76 @@ template<typename T> void pocketfft_general_r(int ndim, const size_t *shape,
it_out.advance();
}
}
template<typename T> void pocketfft_general_r2c(int ndim, const size_t *shape,
const int64_t *stride_in, const int64_t *stride_out, size_t axis,
const T *data_in, cmplx<T> *data_out, T fct)
{
// allocate temporary 1D array storage
arr<T> tdata(shape[axis]);
multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
pocketfft_r<T> plan(shape[axis]);
multi_iter it_in(a_in, axis), it_out(a_out, axis);
size_t len=shape[axis], s_i=it_in.stride(), s_o=it_out.stride();
while (!it_in.done())
{
const T *d_i = data_in+it_in.offset();
cmplx<T> *d_o = data_out+it_out.offset();
for (size_t i=0; i<len; ++i)
tdata[i] = d_i[i*s_i];
plan.forward(tdata.data(), fct);
d_o[0].r = tdata[0];
d_o[0].i = 0.;
size_t i;
for (i=1; i<len-1; i+=2)
{
size_t io = (i+1)/2;
d_o[io*s_o].r = tdata[i];
d_o[io*s_o].i = tdata[i+1];
}
if (i<len)
{
size_t io = (i+1)/2;
d_o[io*s_o].r = tdata[i];
d_o[io*s_o].i = 0.;
}
it_in.advance();
it_out.advance();
}
}
template<typename T> void pocketfft_general_c2r(int ndim, const size_t *shape_out,
const int64_t *stride_in, const int64_t *stride_out, size_t axis,
const cmplx<T> *data_in, T *data_out, T fct)
{
// allocate temporary 1D array storage
arr<T> tdata(shape_out[axis]);
multiarr a_in(ndim, shape_out, stride_in), a_out(ndim, shape_out, stride_out);
pocketfft_r<T> plan(shape_out[axis]);
multi_iter it_in(a_in, axis), it_out(a_out, axis);
size_t len=shape_out[axis], s_i=it_in.stride(), s_o=it_out.stride();
while (!it_in.done())
{
const cmplx<T> *d_i = data_in+it_in.offset();
T *d_o = data_out+it_out.offset();
tdata[0]=d_i[0].r;
size_t i;
for (i=1; i<len-1; i+=2)
{
size_t ii = (i+1)/2;
tdata[i] = d_i[ii*s_i].r;
tdata[i+1] = d_i[ii*s_i].i;
}
if (i<len)
{
size_t ii = (i+1)/2;
tdata[i] = d_i[ii*s_i].r;
}
plan.backward(tdata.data(), fct);
for (size_t i=0; i<len; ++i)
d_o[i*s_o] = tdata[i];
it_in.advance();
it_out.advance();
}
}
namespace py = pybind11;
......@@ -2072,10 +2143,24 @@ auto c128 = py::dtype("complex128");
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
py::array execute(const py::array &in, bool fwd)
void check_args(const py::array &in, vector<size_t> &axes)
{
if (axes.size()==0)
axes.resize(in.ndim());
for (int i=0; i<in.ndim(); ++i)
axes[i] = i;
if (axes.size()>size_t(in.ndim()))
throw runtime_error("bad axes argument");
for (size_t i=0; i<axes.size(); ++i)
if (axes[i]>=size_t(in.ndim()))
throw runtime_error("invalid axis number");
}
py::array execute(const py::array &in, vector<size_t> axes, bool fwd)
{
check_args(in, axes);
py::array res;
vector<size_t> dims(in.ndim()), axes(in.ndim());
vector<size_t> dims(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
......@@ -2092,36 +2177,39 @@ py::array execute(const py::array &in, bool fwd)
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
axes[i] = i;
}
if (in.dtype().is(c128))
pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), in.ndim(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<double> *)in.data(),
(cmplx<double> *)res.mutable_data(), 1.);
else
pocketfft_general_c<float>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), in.ndim(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<float> *)in.data(),
(cmplx<float> *)res.mutable_data(), 1.);
return res;
}
py::array fftn(const py::array &in)
{ return execute(in, true); }
py::array ifftn(const py::array &in)
{ return execute(in, false); }
py::array fftn(const py::array &a, const vector<size_t> &axes)
{ return execute(a, axes, true); }
py::array ifftn(const py::array &a, const vector<size_t> &axes)
{ return execute(a, axes, false); }
py::array execute_r(const py::array &in, int axis, bool fwd)
py::array rfftn(const py::array &in, int axis)
{
py::array res;
vector<size_t> dims(in.ndim());
vector<size_t> dims_in(in.ndim()), dims_out(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
{
dims_in[i] = in.shape(i);
dims_out[i] = in.shape(i);
}
dims_out[axis] = (dims_out[axis]>>1)+1;
if (in.dtype().is(f64))
res = py::array_t<double>(dims);
res = py::array_t<complex<double>>(dims_out);
else if (in.dtype().is(f32))
res = py::array_t<float>(dims);
res = py::array_t<complex<float>>(dims_out);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
......@@ -2133,27 +2221,71 @@ py::array execute_r(const py::array &in, int axis, bool fwd)
throw runtime_error("weird strides");
}
if (in.dtype().is(f64))
pocketfft_general_r<double>(in.ndim(), dims.data(),
pocketfft_general_r2c<double>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axis, (const double *)in.data(),
(cmplx<double> *)res.mutable_data(), 1.);
else
pocketfft_general_r2c<float>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axis, fwd, (const double *)in.data(),
axis, (const float *)in.data(),
(cmplx<float> *)res.mutable_data(), 1.);
return res;
}
py::array irfftn(const py::array &in, int axis, int osize)
{
py::array res;
if (osize/2 + 1 != in.shape(axis))
throw runtime_error("bad output size");
vector<size_t> dims_in(in.ndim()), dims_out(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
{
dims_in[i] = in.shape(i);
dims_out[i] = in.shape(i);
}
dims_out[axis] = osize;
if (in.dtype().is(c128))
res = py::array_t<double>(dims_out);
else if (in.dtype().is(c64))
res = py::array_t<float>(dims_out);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
}
if (in.dtype().is(c128))
pocketfft_general_c2r<double>(in.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, (const cmplx<double> *)in.data(),
(double *)res.mutable_data(), 1.);
else
pocketfft_general_r<float>(in.ndim(), dims.data(),
pocketfft_general_c2r<float>(in.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, fwd, (const float *)in.data(),
axis, (const cmplx<float> *)in.data(),
(float *)res.mutable_data(), 1.);
return res;
}
py::array rfftn(const py::array &in, int axis)
{ return execute_r(in, axis, true); }
py::array irfftn(const py::array &in, int axis)
{ return execute_r(in, axis, false); }
void xtest(const py::array &a, const vector<size_t> &s, const vector<size_t> &axes, const string &norm)
{
// for (size_t i=0; i<data.size(); ++i)
// cout << data[i] << endl;
}
} // unnamed namespace
PYBIND11_MODULE(pypocketfft, m)
{
m.def("fftn",&fftn);
m.def("ifftn",&ifftn);
using namespace pybind11::literals;
m.def("fftn",&fftn, "a"_a, "axes"_a=vector<size_t>());
m.def("ifftn",&ifftn, "a"_a, "axes"_a=vector<size_t>());
m.def("rfftn",&rfftn);
m.def("irfftn",&irfftn);
m.def("xtest",&xtest,"a"_a,"s"_a=vector<size_t>(), "axes"_a=vector<size_t>(), "norm"_a="");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment