Commit 320b7b8d authored by Martin Reinecke
Browse files

cleanups

parent 06d75f2c
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#ifdef __GNUC__ #ifdef __GNUC__
#define NOINLINE __attribute__((noinline)) #define NOINLINE __attribute__((noinline))
...@@ -1047,14 +1049,14 @@ template<typename T0> class pocketfft_c ...@@ -1047,14 +1049,14 @@ template<typename T0> class pocketfft_c
}; };
struct diminfo struct diminfo
{ size_t n, s; }; { size_t n; int64_t s; };
class multiarr class multiarr
{ {
private: private:
vector<diminfo> dim; vector<diminfo> dim;
public: public:
multiarr (size_t ndim_, const size_t *n, const size_t *s) multiarr (size_t ndim_, const size_t *n, const int64_t *s)
{ {
dim.reserve(ndim_); dim.reserve(ndim_);
for (size_t i=0; i<ndim_; ++i) for (size_t i=0; i<ndim_; ++i)
...@@ -1062,7 +1064,7 @@ class multiarr ...@@ -1062,7 +1064,7 @@ class multiarr
} }
size_t ndim() const { return dim.size(); } size_t ndim() const { return dim.size(); }
size_t size(size_t i) const { return dim[i].n; } size_t size(size_t i) const { return dim[i].n; }
size_t stride(size_t i) const { return dim[i].s; } int64_t stride(size_t i) const { return dim[i].s; }
}; };
class multi_iter class multi_iter
...@@ -1070,7 +1072,8 @@ class multi_iter ...@@ -1070,7 +1072,8 @@ class multi_iter
private: private:
vector<diminfo> dim; vector<diminfo> dim;
vector<size_t> pos; vector<size_t> pos;
size_t ofs_, len, str; size_t ofs_, len;
int64_t str;
bool done_; bool done_;
public: public:
...@@ -1102,11 +1105,11 @@ class multi_iter ...@@ -1102,11 +1105,11 @@ class multi_iter
bool done() const { return done_; } bool done() const { return done_; }
size_t offset() const { return ofs_; } size_t offset() const { return ofs_; }
size_t length() const { return len; } size_t length() const { return len; }
size_t stride() const { return str; } int64_t stride() const { return str; }
}; };
template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
const size_t *stride_in, const size_t *stride_out, int nax, const int64_t *stride_in, const int64_t *stride_out, int nax,
const size_t *axes, bool forward, const cmplx<T> *data_in, const size_t *axes, bool forward, const cmplx<T> *data_in,
cmplx<T> *data_out, T fct) cmplx<T> *data_out, T fct)
{ {
...@@ -1115,7 +1118,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, ...@@ -1115,7 +1118,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
for (int iax=0; iax<nax; ++iax) for (int iax=0; iax<nax; ++iax)
{ {
bool inplace = false; bool inplace = false;
size_t stride = (iax==0) ? stride_in[axes[iax]] : stride_out[axes[iax]]; int64_t stride = (iax==0) ? stride_in[axes[iax]] : stride_out[axes[iax]];
if (stride==1) if (stride==1)
if ((iax>0) || (data_in==data_out)) if ((iax>0) || (data_in==data_out))
inplace = true; inplace = true;
...@@ -1158,37 +1161,38 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, ...@@ -1158,37 +1161,38 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
a_in = a_out; a_in = a_out;
data_in = data_out; data_in = data_out;
// factor has been applied, use 1 for remaining axes // factor has been applied, use 1 for remaining axes
fct = 1.; fct = T(1);
} }
} }
} // unnamed namespace
#if 1
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11; namespace py = pybind11;
namespace { auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128");
py::array execute(const py::array &in, bool fwd) py::array execute(const py::array &in, bool fwd)
{ {
py::array res; py::array res;
vector<size_t> dims(in.ndim()), s_i(in.ndim()), s_o(in.ndim()), axes(in.ndim()); vector<size_t> dims(in.ndim()), axes(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i) for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i); dims[i] = in.shape(i);
if (in.dtype().is(py::dtype("complex128"))) if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims); res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(py::dtype("complex64"))) else if (in.dtype().is(c64))
res = py::array_t<complex<float>>(dims); res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type"); else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i) for (int i=0; i<in.ndim(); ++i)
{ {
s_i[i] = in.strides(i)/in.itemsize(); s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize(); s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
axes[i] = i; axes[i] = i;
} }
if (in.dtype().is(py::dtype("complex128"))) if (in.dtype().is(c128))
pocketfft_general_c<double>(in.ndim(), dims.data(), pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), in.ndim(), s_i.data(), s_o.data(), in.ndim(),
axes.data(), fwd, (const cmplx<double> *)in.data(), axes.data(), fwd, (const cmplx<double> *)in.data(),
...@@ -1211,19 +1215,3 @@ PYBIND11_MODULE(pypocketfft, m) ...@@ -1211,19 +1215,3 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("fftn",&fftn); m.def("fftn",&fftn);
m.def("ifftn",&ifftn); m.def("ifftn",&ifftn);
} }
#else
int main()
{
int sz=100;
vector<cmplx<double>> data(sz), data2(sz);
size_t shape[] = {sz};
size_t stride[] = {1};
size_t axes[] = {0};
pocketfft_general_c(1, shape,
stride, stride, 1,
axes, true, data.data(),
data2.data(), 1.);
}
#endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment