Skip to content
Snippets Groups Projects
Commit d65f7227 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

better type dispatching; some preparations for float128

parent 8ac40bf5
No related branches found
No related tags found
1 merge request!6Develop
...@@ -31,17 +31,10 @@ namespace py = pybind11; ...@@ -31,17 +31,10 @@ namespace py = pybind11;
auto c64 = py::dtype("complex64"); auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128"); auto c128 = py::dtype("complex128");
//auto c256 = py::dtype("complex256");
auto f32 = py::dtype("float32"); auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64"); auto f64 = py::dtype("float64");
//auto f128 = py::dtype("float128");
bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
{
if (arr.dtype().is(t1))
return true;
if (arr.dtype().is(t2))
return false;
throw runtime_error("unsupported data type");
}
shape_t copy_shape(const py::array &arr) shape_t copy_shape(const py::array &arr)
{ {
...@@ -77,6 +70,13 @@ shape_t makeaxes(const py::array &in, py::object axes) ...@@ -77,6 +70,13 @@ shape_t makeaxes(const py::array &in, py::object axes)
return tmp; return tmp;
} }
#define DISPATCH(arr, T1, T2, T3, func, args) \
auto dtype = arr.dtype(); \
if (dtype.is(T1)) return func<double> args; \
if (dtype.is(T2)) return func<float> args; \
/* if (dtype.is(T3)) return func<long double> args; */ \
throw runtime_error("unsupported data type");
template<typename T> py::array xfftn_internal(const py::array &in, template<typename T> py::array xfftn_internal(const py::array &in,
const shape_t &axes, double fct, bool inplace, bool fwd) const shape_t &axes, double fct, bool inplace, bool fwd)
{ {
...@@ -91,12 +91,13 @@ template<typename T> py::array xfftn_internal(const py::array &in, ...@@ -91,12 +91,13 @@ template<typename T> py::array xfftn_internal(const py::array &in,
py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace, py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
bool fwd) bool fwd)
{ {
return tcheck(a, c128, c64) ? DISPATCH(a, c128, c64, c256, xfftn_internal, (a, makeaxes(a, axes), fct,
xfftn_internal<double>(a, makeaxes(a, axes), fct, inplace, fwd) : inplace, fwd))
xfftn_internal<float> (a, makeaxes(a, axes), fct, inplace, fwd);
} }
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace) py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, inplace, true); } { return xfftn(a, axes, fct, inplace, true); }
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace) py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, inplace, false); } { return xfftn(a, axes, fct, inplace, false); }
...@@ -112,11 +113,12 @@ template<typename T> py::array rfftn_internal(const py::array &in, ...@@ -112,11 +113,12 @@ template<typename T> py::array rfftn_internal(const py::array &in,
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct)); reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array rfftn(const py::array &in, py::object axes_, double fct) py::array rfftn(const py::array &in, py::object axes_, double fct)
{ {
return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct) DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct))
: rfftn_internal<float> (in, axes_, fct);
} }
template<typename T> py::array xrfft_scipy(const py::array &in, template<typename T> py::array xrfft_scipy(const py::array &in,
size_t axis, double fct, bool inplace, bool fwd) size_t axis, double fct, bool inplace, bool fwd)
{ {
...@@ -127,18 +129,16 @@ template<typename T> py::array xrfft_scipy(const py::array &in, ...@@ -127,18 +129,16 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct)); reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace) py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
{ {
return tcheck(in, f64, f32) ? DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true))
xrfft_scipy<double>(in, axis, fct, inplace, true) :
xrfft_scipy<float> (in, axis, fct, inplace, true);
} }
py::array irfft_scipy(const py::array &in, size_t axis, double fct, py::array irfft_scipy(const py::array &in, size_t axis, double fct,
bool inplace) bool inplace)
{ {
return tcheck(in, f64, f32) ? DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false))
xrfft_scipy<double>(in, axis, fct, inplace, false) :
xrfft_scipy<float> (in, axis, fct, inplace, false);
} }
template<typename T> py::array irfftn_internal(const py::array &in, template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct) py::object axes_, size_t lastsize, T fct)
...@@ -156,12 +156,11 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -156,12 +156,11 @@ template<typename T> py::array irfftn_internal(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct)); reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct) double fct)
{ {
return tcheck(in, c128, c64) ? DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct))
irfftn_internal<double>(in, axes_, lastsize, fct) :
irfftn_internal<float> (in, axes_, lastsize, fct);
} }
template<typename T> py::array hartley_internal(const py::array &in, template<typename T> py::array hartley_internal(const py::array &in,
...@@ -174,12 +173,11 @@ template<typename T> py::array hartley_internal(const py::array &in, ...@@ -174,12 +173,11 @@ template<typename T> py::array hartley_internal(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct)); reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array hartley(const py::array &in, py::object axes_, double fct, py::array hartley(const py::array &in, py::object axes_, double fct,
bool inplace) bool inplace)
{ {
return tcheck(in, f64, f32) ? DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace))
hartley_internal<double>(in, axes_, fct, inplace) :
hartley_internal<float> (in, axes_, fct, inplace);
} }
template<typename T>py::array complex2hartley(const py::array &in, template<typename T>py::array complex2hartley(const py::array &in,
...@@ -224,12 +222,13 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -224,12 +222,13 @@ template<typename T>py::array complex2hartley(const py::array &in,
} }
return out; return out;
} }
py::array mycomplex2hartley(const py::array &in, py::array mycomplex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace) const py::array &tmp, py::object axes_, bool inplace)
{ {
return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_, inplace) DISPATCH(in, f64, f32, f128, complex2hartley, (in, tmp, axes_, inplace))
: complex2hartley<float> (in, tmp, axes_, inplace);
} }
py::array hartley2(const py::array &in, py::object axes_, double fct, py::array hartley2(const py::array &in, py::object axes_, double fct,
bool inplace) bool inplace)
{ return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); } { return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment