Commit 320b7b8d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanups

parent 06d75f2c
......@@ -4,6 +4,8 @@
#include <iostream>
#include <memory>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
......@@ -1047,14 +1049,14 @@ template<typename T0> class pocketfft_c
};
struct diminfo
{ size_t n, s; };
{ size_t n; int64_t s; };
class multiarr
{
private:
vector<diminfo> dim;
public:
multiarr (size_t ndim_, const size_t *n, const size_t *s)
multiarr (size_t ndim_, const size_t *n, const int64_t *s)
{
dim.reserve(ndim_);
for (size_t i=0; i<ndim_; ++i)
......@@ -1062,7 +1064,7 @@ class multiarr
}
size_t ndim() const { return dim.size(); }
size_t size(size_t i) const { return dim[i].n; }
size_t stride(size_t i) const { return dim[i].s; }
int64_t stride(size_t i) const { return dim[i].s; }
};
class multi_iter
......@@ -1070,7 +1072,8 @@ class multi_iter
private:
vector<diminfo> dim;
vector<size_t> pos;
size_t ofs_, len, str;
size_t ofs_, len;
int64_t str;
bool done_;
public:
......@@ -1102,11 +1105,11 @@ class multi_iter
bool done() const { return done_; }
size_t offset() const { return ofs_; }
size_t length() const { return len; }
size_t stride() const { return str; }
int64_t stride() const { return str; }
};
template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
const size_t *stride_in, const size_t *stride_out, int nax,
const int64_t *stride_in, const int64_t *stride_out, int nax,
const size_t *axes, bool forward, const cmplx<T> *data_in,
cmplx<T> *data_out, T fct)
{
......@@ -1115,7 +1118,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
for (int iax=0; iax<nax; ++iax)
{
bool inplace = false;
size_t stride = (iax==0) ? stride_in[axes[iax]] : stride_out[axes[iax]];
int64_t stride = (iax==0) ? stride_in[axes[iax]] : stride_out[axes[iax]];
if (stride==1)
if ((iax>0) || (data_in==data_out))
inplace = true;
......@@ -1158,37 +1161,38 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
a_in = a_out;
data_in = data_out;
// factor has been applied, use 1 for remaining axes
fct = 1.;
fct = T(1);
}
}
} // unnamed namespace
#if 1
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
namespace {
auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128");
py::array execute(const py::array &in, bool fwd)
{
py::array res;
vector<size_t> dims(in.ndim()), s_i(in.ndim()), s_o(in.ndim()), axes(in.ndim());
vector<size_t> dims(in.ndim()), axes(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
if (in.dtype().is(py::dtype("complex128")))
if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(py::dtype("complex64")))
else if (in.dtype().is(c64))
res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
axes[i] = i;
}
if (in.dtype().is(py::dtype("complex128")))
if (in.dtype().is(c128))
pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), in.ndim(),
axes.data(), fwd, (const cmplx<double> *)in.data(),
......@@ -1211,19 +1215,3 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("fftn",&fftn);
m.def("ifftn",&ifftn);
}
#else
int main()
{
int sz=100;
vector<cmplx<double>> data(sz), data2(sz);
size_t shape[] = {sz};
size_t stride[] = {1};
size_t axes[] = {0};
pocketfft_general_c(1, shape,
stride, stride, 1,
axes, true, data.data(),
data2.data(), 1.);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment