Commit 02db4525 authored by Martin Reinecke's avatar Martin Reinecke

structure

parent 97730839
......@@ -72,6 +72,10 @@ template<typename T> struct arr
size_t size() const { return sz; }
};
//
// twiddle factor section
//
class sincos_2pibyn
{
private:
......@@ -357,6 +361,10 @@ template<typename T> void ROTM90(cmplx<T> &a)
constexpr size_t NFCT=25;
//
// complex FFTPACK transforms
//
template<typename T0> class cfftp
{
private:
......@@ -940,6 +948,9 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
}
};
//
// real-valued FFTPACK transforms
//
template<typename T0> class rfftp
{
......@@ -1740,6 +1751,10 @@ NOINLINE rfftp(size_t length_)
};
//
// complex Bluestein transforms
//
template<typename T0> class fftblue
{
private:
......@@ -1854,6 +1869,10 @@ template<typename T0> class fftblue
}
};
//
// flexible (FFTPACK/Bluestein) complex 1D transform
//
template<typename T0> class pocketfft_c
{
private:
......@@ -1894,6 +1913,11 @@ template<typename T0> class pocketfft_c
size_t length() const { return len; }
};
//
// flexible (FFTPACK/Bluestein) real-valued 1D transform
//
template<typename T0> class pocketfft_r
{
private:
......@@ -1935,6 +1959,10 @@ template<typename T0> class pocketfft_r
size_t length() const { return len; }
};
//
// multi-D infrastructure
//
struct diminfo
{ size_t n; int64_t s; };
class multiarr
......@@ -2345,6 +2373,10 @@ template<typename T> void pocketfft_general_c2r(const vector<size_t> &shape_out,
}
}
//
// Python interface
//
namespace py = pybind11;
auto c64 = py::dtype("complex64");
......@@ -2359,6 +2391,7 @@ vector<size_t> copy_shape(const py::array &arr)
res[i] = arr.shape(i);
return res;
}
vector<int64_t> copy_strides(const py::array &arr)
{
vector<int64_t> res(arr.ndim());
......@@ -2389,32 +2422,29 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
return tmp;
}
template<typename T> py::array execute(const py::array &in, const vector<size_t> &axes, double fct, bool fwd)
template<typename T> py::array xfftn_internal(const py::array &in,
const vector<size_t> &axes, double fct, bool fwd)
{
auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_c<T>(dims,
s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(),
pocketfft_general_c<T>(dims, s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
return res;
}
py::array fftn(const py::array &a, py::object axes, double fct)
py::array xfftn(const py::array &a, py::object axes, double fct, bool fwd)
{
if (a.dtype().is(c128))
return execute<double>(a, makeaxes(a, axes), fct, true);
return xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd);
else if (a.dtype().is(c64))
return execute<float>(a, makeaxes(a, axes), fct, true);
return xfftn_internal<float>(a, makeaxes(a, axes), fct, fwd);
throw runtime_error("unsupported data type");
}
py::array fftn(const py::array &a, py::object axes, double fct)
{ return xfftn(a, axes, fct, true); }
py::array ifftn(const py::array &a, py::object axes, double fct)
{
if (a.dtype().is(c128))
return execute<double>(a, makeaxes(a, axes), fct, false);
else if (a.dtype().is(c64))
return execute<float>(a, makeaxes(a, axes), fct, false);
throw runtime_error("unsupported data type");
}
{ return xfftn(a, axes, fct, false); }
template<typename T> py::array rfftn_internal(const py::array &in,
py::object axes_, T fct)
......@@ -2452,7 +2482,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = execute<T>(in, axes2, 1., false);
inter = xfftn_internal<T>(in, axes2, 1., false);
}
else
inter = in;
......@@ -2482,7 +2512,7 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
template<typename T> py::array hartley_internal(const py::array &in,
py::object axes_, double fct)
{
auto axes = makeaxes(in, axes_);
auto axes(makeaxes(in, axes_));
auto dims(copy_shape(in));
py::array res = py::array_t<T>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment