From 46b8e3af62e4963bb01926a2db237a9707d307a7 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 24 Apr 2019 00:27:17 +0200
Subject: [PATCH] cleanup

---
 pocketfft.cc | 101 +++++++++++++++++++++++++--------------------------
 1 file changed, 50 insertions(+), 51 deletions(-)

diff --git a/pocketfft.cc b/pocketfft.cc
index 6f54dcc..80505b5 100644
--- a/pocketfft.cc
+++ b/pocketfft.cc
@@ -1581,7 +1581,7 @@ template<typename T> NOINLINE void radbg(size_t ido, size_t ip, size_t l1,
 #undef MULPM
 #undef WA
 
-template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
+template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
   {
   if (p1!=c)
     {
@@ -1963,6 +1963,9 @@ template<typename T0> class pocketfft_r
 // multi-D infrastructure
 //
 
+using shape_t = vector<size_t>;
+using stride_t = vector<int64_t>;
+
 struct diminfo
   { size_t n; int64_t s; };
 class multiarr
@@ -1971,7 +1974,7 @@ class multiarr
     vector<diminfo> dim;
 
   public:
-    multiarr (const vector<size_t> &n, const vector<int64_t> &s)
+    multiarr (const shape_t &n, const stride_t &s)
       {
       dim.reserve(n.size());
       for (size_t i=0; i<n.size(); ++i)
@@ -1986,7 +1989,7 @@ class multi_iter
   {
   public:
     vector<diminfo> dim;
-    vector<size_t> pos;
+    shape_t pos;
     size_t ofs_, len;
     int64_t str;
     int64_t rem;
@@ -2054,7 +2057,7 @@ template<> struct VTYPE<float>
   { using type = __m128; };
 #endif
 
-template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
+template<typename T> arr<char> alloc_tmp(const shape_t &shape,
   size_t tmpsize, size_t elemsize)
   {
 #ifdef HAVE_VECSUPPORT
@@ -2064,8 +2067,8 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
 #endif
   return arr<char>(tmpsize*elemsize);
   }
-template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
-  const vector<size_t> &axes, size_t elemsize)
+template<typename T> arr<char> alloc_tmp(const shape_t &shape,
+  const shape_t &axes, size_t elemsize)
   {
   size_t fullsize=1;
   size_t ndim = shape.size();
@@ -2078,9 +2081,9 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
   return alloc_tmp<T>(shape, tmpsize, elemsize);
   }
 
-template<typename T> void pocketfft_general_c(const vector<size_t> &shape,
-  const vector<int64_t> &stride_in, const vector<int64_t> &stride_out,
-  const vector<size_t> &axes, bool forward, const cmplx<T> *data_in,
+template<typename T> void pocketfft_general_c(const shape_t &shape,
+  const stride_t &stride_in, const stride_t &stride_out,
+  const shape_t &axes, bool forward, const cmplx<T> *data_in,
   cmplx<T> *data_out, T fct)
   {
   auto storage = alloc_tmp<T>(shape, axes, sizeof(cmplx<T>));
@@ -2144,9 +2147,9 @@ template<typename T> void pocketfft_general_c(const vector<size_t> &shape,
     }
   }
 
-template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape,
-  const vector<int64_t> &stride_in, const vector<int64_t> &stride_out,
-  const vector<size_t> &axes, const T *data_in, T *data_out, T fct)
+template<typename T> void pocketfft_general_hartley(const shape_t &shape,
+  const stride_t &stride_in, const stride_t &stride_out,
+  const shape_t &axes, const T *data_in, T *data_out, T fct)
   {
   auto storage = alloc_tmp<T>(shape, axes, sizeof(T));
 #ifdef HAVE_VECSUPPORT
@@ -2218,8 +2221,8 @@ template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape,
     }
   }
 
-template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape,
-  const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis,
+template<typename T> void pocketfft_general_r2c(const shape_t &shape,
+  const stride_t &stride_in, const stride_t &stride_out, size_t axis,
   const T *data_in, cmplx<T> *data_out, T fct)
   {
   auto storage = alloc_tmp<T>(shape, shape[axis], sizeof(T));
@@ -2299,8 +2302,8 @@ template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape,
     it_out.advance();
     }
   }
-template<typename T> void pocketfft_general_c2r(const vector<size_t> &shape_out,
-  const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis,
+template<typename T> void pocketfft_general_c2r(const shape_t &shape_out,
+  const stride_t &stride_in, const stride_t &stride_out, size_t axis,
   const cmplx<T> *data_in, T *data_out, T fct)
   {
   auto storage = alloc_tmp<T>(shape_out, shape_out[axis], sizeof(T));
@@ -2384,17 +2387,26 @@ auto c128 = py::dtype("complex128");
 auto f32 = py::dtype("float32");
 auto f64 = py::dtype("float64");
 
-vector<size_t> copy_shape(const py::array &arr)
+bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
+  {
+  if (arr.dtype().is(t1))
+    return true;
+  if (arr.dtype().is(t2))
+    return false;
+  throw runtime_error("unsupported data type");
+  }
+
+shape_t copy_shape(const py::array &arr)
   {
-  vector<size_t> res(arr.ndim());
+  shape_t res(arr.ndim());
   for (size_t i=0; i<res.size(); ++i)
     res[i] = arr.shape(i);
   return res;
   }
 
-vector<int64_t> copy_strides(const py::array &arr)
+stride_t copy_strides(const py::array &arr)
   {
-  vector<int64_t> res(arr.ndim());
+  stride_t res(arr.ndim());
   for (size_t i=0; i<res.size(); ++i)
     {
     res[i] = arr.strides(i)/arr.itemsize();
@@ -2404,16 +2416,16 @@ vector<int64_t> copy_strides(const py::array &arr)
   return res;
   }
 
-vector<size_t> makeaxes(const py::array &in, py::object axes)
+shape_t makeaxes(const py::array &in, py::object axes)
   {
   if (axes.is(py::none()))
     {
-    vector<size_t> res(in.ndim());
+    shape_t res(in.ndim());
     for (size_t i=0; i<res.size(); ++i)
       res[i]=i;
     return res;
     }
-  auto tmp=axes.cast<vector<size_t>>();
+  auto tmp=axes.cast<shape_t>();
   if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0))
     throw runtime_error("bad axes argument");
   for (size_t i=0; i<tmp.size(); ++i)
@@ -2423,7 +2435,7 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
   }
 
 template<typename T> py::array xfftn_internal(const py::array &in,
-  const vector<size_t> &axes, double fct, bool fwd)
+  const shape_t &axes, double fct, bool fwd)
   {
   auto dims(copy_shape(in));
   py::array res = py::array_t<complex<T>>(dims);
@@ -2435,11 +2447,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
 
 py::array xfftn(const py::array &a, py::object axes, double fct, bool fwd)
   {
-  if (a.dtype().is(c128))
-    return xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd);
-  else if (a.dtype().is(c64))
-    return xfftn_internal<float>(a, makeaxes(a, axes), fct, fwd);
-  throw runtime_error("unsupported data type");
+  return tcheck(a, c128, c64) ?
+    xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd) :
+    xfftn_internal<float> (a, makeaxes(a, axes), fct, fwd);
   }
 py::array fftn(const py::array &a, py::object axes, double fct)
   { return xfftn(a, axes, fct, true); }
@@ -2457,7 +2467,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
   pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
     (cmplx<T> *)res.mutable_data(), fct);
   if (axes.size()==1) return res;
-  vector<size_t> axes2(axes.size()-1);
+  shape_t axes2(axes.size()-1);
   for (size_t i=0; i<axes2.size(); ++i)
     axes2[i] = axes[i];
   pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
@@ -2466,11 +2476,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
   }
 py::array rfftn(const py::array &in, py::object axes_, double fct)
   {
-  if (in.dtype().is(f64))
-    return rfftn_internal<double>(in, axes_, fct);
-  else if (in.dtype().is(f32))
-    return rfftn_internal<float>(in, axes_, fct);
-  else throw runtime_error("unsupported data type");
+  return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct)
+                              : rfftn_internal<float> (in, axes_, fct);
   }
 template<typename T> py::array irfftn_internal(const py::array &in,
   py::object axes_, size_t lastsize, T fct)
@@ -2479,7 +2486,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
   py::array inter;
   if (axes.size()>1) // need to do complex transforms
     {
-    vector<size_t> axes2(axes.size()-1);
+    shape_t axes2(axes.size()-1);
     for (size_t i=0; i<axes2.size(); ++i)
       axes2[i] = axes[i];
     inter = xfftn_internal<T>(in, axes2, 1., false);
@@ -2502,11 +2509,9 @@ template<typename T> py::array irfftn_internal(const py::array &in,
 py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
   double fct)
   {
-  if (in.dtype().is(c128))
-    return irfftn_internal<double>(in, axes_, lastsize, fct);
-  else if (in.dtype().is(c64))
-    return irfftn_internal<float>(in, axes_, lastsize, fct);
-  throw runtime_error("unsupported data type");
+  return tcheck(in, c128, c64) ?
+    irfftn_internal<double>(in, axes_, lastsize, fct) :
+    irfftn_internal<float> (in, axes_, lastsize, fct);
   }
 
 template<typename T> py::array hartley_internal(const py::array &in,
@@ -2522,11 +2527,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
   }
 py::array hartley(const py::array &in, py::object axes_, double fct)
   {
-  if (in.dtype().is(f64))
-    return hartley_internal<double>(in, axes_, fct);
-  else if (in.dtype().is(f32))
-    return hartley_internal<float>(in, axes_, fct);
-  throw runtime_error("unsupported data type");
+  return tcheck(in, f64, f32) ? hartley_internal<double>(in, axes_, fct)
+                              : hartley_internal<float> (in, axes_, fct);
   }
 
 template<typename T>py::array complex2hartley(const py::array &in,
@@ -2579,11 +2581,8 @@ template<typename T>py::array complex2hartley(const py::array &in,
 py::array mycomplex2hartley(const py::array &in,
   const py::array &tmp, py::object axes_)
   {
-  if (in.dtype().is(f64))
-    return complex2hartley<double>(in, tmp, axes_);
-  else if (in.dtype().is(f32))
-    return complex2hartley<float>(in, tmp, axes_);
-  else throw runtime_error("unsupported data type");
+  return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_)
+                              : complex2hartley<float> (in, tmp, axes_);
   }
 py::array hartley2(const py::array &in, py::object axes_, double fct)
   { return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_); }
-- 
GitLab