From d65f7227fa9bb42ff237eebe172e8973742c7772 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 10 May 2019 19:59:10 +0200
Subject: [PATCH] better type dispatching; some preparations for float128

---
 pypocketfft.cc | 55 +++++++++++++++++++++++++-------------------------
 1 file changed, 27 insertions(+), 28 deletions(-)

diff --git a/pypocketfft.cc b/pypocketfft.cc
index e96a0da..2b4b873 100644
--- a/pypocketfft.cc
+++ b/pypocketfft.cc
@@ -31,17 +31,10 @@ namespace py = pybind11;
 
 auto c64 = py::dtype("complex64");
 auto c128 = py::dtype("complex128");
+//auto c256 = py::dtype("complex256");
 auto f32 = py::dtype("float32");
 auto f64 = py::dtype("float64");
-
-bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
-  {
-  if (arr.dtype().is(t1))
-    return true;
-  if (arr.dtype().is(t2))
-    return false;
-  throw runtime_error("unsupported data type");
-  }
+//auto f128 = py::dtype("float128");
 
 shape_t copy_shape(const py::array &arr)
   {
@@ -77,6 +70,13 @@ shape_t makeaxes(const py::array &in, py::object axes)
   return tmp;
   }
 
+#define DISPATCH(arr, T1, T2, T3, func, args) \
+  auto dtype = arr.dtype(); \
+  if (dtype.is(T1)) return func<double> args; \
+  if (dtype.is(T2)) return func<float> args; \
+/*  if (dtype.is(T3)) return func<long double> args; */ \
+  throw runtime_error("unsupported data type");
+
 template<typename T> py::array xfftn_internal(const py::array &in,
   const shape_t &axes, double fct, bool inplace, bool fwd)
   {
@@ -91,12 +91,13 @@ template<typename T> py::array xfftn_internal(const py::array &in,
 py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
   bool fwd)
   {
-  return tcheck(a, c128, c64) ?
-    xfftn_internal<double>(a, makeaxes(a, axes), fct, inplace, fwd) :
-    xfftn_internal<float> (a, makeaxes(a, axes), fct, inplace, fwd);
+  DISPATCH(a, c128, c64, c256, xfftn_internal, (a, makeaxes(a, axes), fct,
+           inplace, fwd))
   }
+
 py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
   { return xfftn(a, axes, fct, inplace, true); }
+
 py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
   { return xfftn(a, axes, fct, inplace, false); }
 
@@ -112,11 +113,12 @@ template<typename T> py::array rfftn_internal(const py::array &in,
     reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
   return res;
   }
+
 py::array rfftn(const py::array &in, py::object axes_, double fct)
   {
-  return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct)
-                              : rfftn_internal<float> (in, axes_, fct);
+  DISPATCH(in, f64, f32, f128, rfftn_internal, (in, axes_, fct))
   }
+
 template<typename T> py::array xrfft_scipy(const py::array &in,
   size_t axis, double fct, bool inplace, bool fwd)
   {
@@ -127,18 +129,16 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
     reinterpret_cast<T *>(res.mutable_data()), T(fct));
   return res;
   }
+
 py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
   {
-  return tcheck(in, f64, f32) ?
-    xrfft_scipy<double>(in, axis, fct, inplace, true) :
-    xrfft_scipy<float> (in, axis, fct, inplace, true);
+  DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, true))
   }
+
 py::array irfft_scipy(const py::array &in, size_t axis, double fct,
   bool inplace)
   {
-  return tcheck(in, f64, f32) ?
-    xrfft_scipy<double>(in, axis, fct, inplace, false) :
-    xrfft_scipy<float> (in, axis, fct, inplace, false);
+  DISPATCH(in, f64, f32, f128, xrfft_scipy, (in, axis, fct, inplace, false))
   }
 template<typename T> py::array irfftn_internal(const py::array &in,
   py::object axes_, size_t lastsize, T fct)
@@ -156,12 +156,11 @@ template<typename T> py::array irfftn_internal(const py::array &in,
     reinterpret_cast<T *>(res.mutable_data()), T(fct));
   return res;
   }
+
 py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
   double fct)
   {
-  return tcheck(in, c128, c64) ?
-    irfftn_internal<double>(in, axes_, lastsize, fct) :
-    irfftn_internal<float> (in, axes_, lastsize, fct);
+  DISPATCH(in, c128, c64, c256, irfftn_internal, (in, axes_, lastsize, fct))
   }
 
 template<typename T> py::array hartley_internal(const py::array &in,
@@ -174,12 +173,11 @@ template<typename T> py::array hartley_internal(const py::array &in,
     reinterpret_cast<T *>(res.mutable_data()), T(fct));
   return res;
   }
+
 py::array hartley(const py::array &in, py::object axes_, double fct,
   bool inplace)
   {
-  return tcheck(in, f64, f32) ?
-    hartley_internal<double>(in, axes_, fct, inplace) :
-    hartley_internal<float> (in, axes_, fct, inplace);
+  DISPATCH(in, f64, f32, f128, hartley_internal, (in, axes_, fct, inplace))
   }
 
 template<typename T>py::array complex2hartley(const py::array &in,
@@ -224,12 +222,13 @@ template<typename T>py::array complex2hartley(const py::array &in,
     }
   return out;
   }
+
 py::array mycomplex2hartley(const py::array &in,
   const py::array &tmp, py::object axes_, bool inplace)
   {
-  return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_, inplace)
-                              : complex2hartley<float> (in, tmp, axes_, inplace);
+  DISPATCH(in, f64, f32, f128, complex2hartley, (in, tmp, axes_, inplace))
   }
+
 py::array hartley2(const py::array &in, py::object axes_, double fct,
   bool inplace)
   { return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); }
-- 
GitLab