Commit d65f7227 authored by Martin Reinecke's avatar Martin Reinecke

better type dispatching; some preparations for float128

parent 8ac40bf5
......@@ -31,17 +31,10 @@ namespace py = pybind11;
auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128");
//auto c256 = py::dtype("complex256");
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
{
if (arr.dtype().is(t1))
return true;
if (arr.dtype().is(t2))
return false;
throw runtime_error("unsupported data type");
}
//auto f128 = py::dtype("float128");
shape_t copy_shape(const py::array &arr)
{
......@@ -77,6 +70,13 @@ shape_t makeaxes(const py::array &in, py::object axes)
return tmp;
}
#define DISPATCH(arr, T1, T2, T3, func, args) \
auto dtype = arr.dtype(); \
if (dtype.is(T1)) return func<double> args; \
if (dtype.is(T2)) return func<float> args; \
/* if (dtype.is(T3)) return func<long double> args; */ \
throw runtime_error("unsupported data type");
template<typename T> py::array xfftn_internal(const py::array &in,
const shape_t &axes, double fct, bool inplace, bool fwd)
{
......@@ -91,12 +91,13 @@ template<typename T> py::array xfftn_internal(const py::array &in,
py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
bool fwd)
{
return tcheck(a, c128, c64) ?
xfftn_internal<double>(a, makeaxes(a, axes), fct, inplace, fwd) :
xfftn_internal<float> (a, makeaxes(a, axes), fct, inplace, fwd);
DISPATCH(a, c128, c64, c256, xfftn_internal, (a, makeaxes(a, axes), fct,
inplace, fwd))
}
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, inplace, true); }
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, inplace, false); }
......@@ -112,11 +113,12 @@ template<typename T> py::array rfftn_internal(const py::array &in,
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
{
return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct)
: rfftn_internal<float> (in, axes_, fct);
DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct))
}
template<typename T> py::array xrfft_scipy(const py::array &in,
size_t axis, double fct, bool inplace, bool fwd)
{
......@@ -127,18 +129,16 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
{
return tcheck(in, f64, f32) ?
xrfft_scipy<double>(in, axis, fct, inplace, true) :
xrfft_scipy<float> (in, axis, fct, inplace, true);
DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true))
}
py::array irfft_scipy(const py::array &in, size_t axis, double fct,
bool inplace)
{
return tcheck(in, f64, f32) ?
xrfft_scipy<double>(in, axis, fct, inplace, false) :
xrfft_scipy<float> (in, axis, fct, inplace, false);
DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false))
}
template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct)
......@@ -156,12 +156,11 @@ template<typename T> py::array irfftn_internal(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct)
{
return tcheck(in, c128, c64) ?
irfftn_internal<double>(in, axes_, lastsize, fct) :
irfftn_internal<float> (in, axes_, lastsize, fct);
DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct))
}
template<typename T> py::array hartley_internal(const py::array &in,
......@@ -174,12 +173,11 @@ template<typename T> py::array hartley_internal(const py::array &in,
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct,
bool inplace)
{
return tcheck(in, f64, f32) ?
hartley_internal<double>(in, axes_, fct, inplace) :
hartley_internal<float> (in, axes_, fct, inplace);
DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace))
}
template<typename T>py::array complex2hartley(const py::array &in,
......@@ -224,12 +222,13 @@ template<typename T>py::array complex2hartley(const py::array &in,
}
return out;
}
py::array mycomplex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace)
{
return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_, inplace)
: complex2hartley<float> (in, tmp, axes_, inplace);
DISPATCH(in, f64, f32, f128, complex2hartley, (in, tmp, axes_, inplace))
}
py::array hartley2(const py::array &in, py::object axes_, double fct,
bool inplace)
{ return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment