Commit 46b8e3af authored by Martin Reinecke

cleanup

parent 02db4525
......@@ -1581,7 +1581,7 @@ template<typename T> NOINLINE void radbg(size_t ido, size_t ip, size_t l1,
#undef MULPM
#undef WA
template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
{
if (p1!=c)
{
......@@ -1963,6 +1963,9 @@ template<typename T0> class pocketfft_r
// multi-D infrastructure
//
// Vector of unsigned sizes/indices. In this file it is used for array shapes,
// for lists of axis indices, and for the position vector of multi_iter.
using shape_t = vector<size_t>;
// Vector of signed strides, one per dimension. copy_strides() fills it with
// the array's byte strides divided by the item size.
using stride_t = vector<int64_t>;
// Plain aggregate holding the description of one array dimension.
// NOTE(review): multiarr is constructed from a shape_t named `n` and a
// stride_t named `s`, so these are presumably the extent and the stride of
// the dimension - confirm against the multiarr constructor body.
struct diminfo
  {
  size_t n;   // unsigned per-dimension count (presumably the extent)
  int64_t s;  // signed per-dimension step (presumably the stride)
  };
class multiarr
......@@ -1971,7 +1974,7 @@ class multiarr
vector<diminfo> dim;
public:
multiarr (const vector<size_t> &n, const vector<int64_t> &s)
multiarr (const shape_t &n, const stride_t &s)
{
dim.reserve(n.size());
for (size_t i=0; i<n.size(); ++i)
......@@ -1986,7 +1989,7 @@ class multi_iter
{
public:
vector<diminfo> dim;
vector<size_t> pos;
shape_t pos;
size_t ofs_, len;
int64_t str;
int64_t rem;
......@@ -2054,7 +2057,7 @@ template<> struct VTYPE<float>
{ using type = __m128; };
#endif
template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t tmpsize, size_t elemsize)
{
#ifdef HAVE_VECSUPPORT
......@@ -2064,8 +2067,8 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
#endif
return arr<char>(tmpsize*elemsize);
}
template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
const vector<size_t> &axes, size_t elemsize)
template<typename T> arr<char> alloc_tmp(const shape_t &shape,
const shape_t &axes, size_t elemsize)
{
size_t fullsize=1;
size_t ndim = shape.size();
......@@ -2078,9 +2081,9 @@ template<typename T> arr<char> alloc_tmp(const vector<size_t> &shape,
return alloc_tmp<T>(shape, tmpsize, elemsize);
}
template<typename T> void pocketfft_general_c(const vector<size_t> &shape,
const vector<int64_t> &stride_in, const vector<int64_t> &stride_out,
const vector<size_t> &axes, bool forward, const cmplx<T> *data_in,
template<typename T> void pocketfft_general_c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out,
const shape_t &axes, bool forward, const cmplx<T> *data_in,
cmplx<T> *data_out, T fct)
{
auto storage = alloc_tmp<T>(shape, axes, sizeof(cmplx<T>));
......@@ -2144,9 +2147,9 @@ template<typename T> void pocketfft_general_c(const vector<size_t> &shape,
}
}
template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape,
const vector<int64_t> &stride_in, const vector<int64_t> &stride_out,
const vector<size_t> &axes, const T *data_in, T *data_out, T fct)
template<typename T> void pocketfft_general_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out,
const shape_t &axes, const T *data_in, T *data_out, T fct)
{
auto storage = alloc_tmp<T>(shape, axes, sizeof(T));
#ifdef HAVE_VECSUPPORT
......@@ -2218,8 +2221,8 @@ template<typename T> void pocketfft_general_hartley(const vector<size_t> &shape,
}
}
template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape,
const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis,
template<typename T> void pocketfft_general_r2c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const T *data_in, cmplx<T> *data_out, T fct)
{
auto storage = alloc_tmp<T>(shape, shape[axis], sizeof(T));
......@@ -2299,8 +2302,8 @@ template<typename T> void pocketfft_general_r2c(const vector<size_t> &shape,
it_out.advance();
}
}
template<typename T> void pocketfft_general_c2r(const vector<size_t> &shape_out,
const vector<int64_t> &stride_in, const vector<int64_t> &stride_out, size_t axis,
template<typename T> void pocketfft_general_c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const cmplx<T> *data_in, T *data_out, T fct)
{
auto storage = alloc_tmp<T>(shape_out, shape_out[axis], sizeof(T));
......@@ -2384,17 +2387,26 @@ auto c128 = py::dtype("complex128");
// NumPy dtype objects for single- and double-precision real data. The
// dispatch functions in this file compare an array's dtype against them
// with dtype().is(...) to select the float or double instantiation.
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
vector<size_t> copy_shape(const py::array &arr)
// Two-way dtype selector used by the Python-facing dispatch functions.
//   arr : array whose dtype is inspected
//   t1  : dtype object that selects the "true" outcome
//   t2  : dtype object that selects the "false" outcome
// Returns true if arr.dtype() is t1 and false if it is t2; t1 is tested first.
// Throws runtime_error("unsupported data type") if it is neither.
bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
  {
  const auto dt = arr.dtype();
  if (dt.is(t1)) return true;
  if (!dt.is(t2))
    throw runtime_error("unsupported data type");
  return false;
  }
// Copies the shape of a NumPy array into a shape_t.
//   arr : array to inspect
// Returns a vector with arr.ndim() entries, where entry i is arr.shape(i).
shape_t copy_shape(const py::array &arr)
  {
  // Single declaration of the result; the block previously declared `res`
  // twice (once as vector<size_t>, once as shape_t), which is a redefinition.
  shape_t res(arr.ndim());
  for (size_t i=0; i<res.size(); ++i)
    res[i] = arr.shape(i);
  return res;
  }
vector<int64_t> copy_strides(const py::array &arr)
stride_t copy_strides(const py::array &arr)
{
vector<int64_t> res(arr.ndim());
stride_t res(arr.ndim());
for (size_t i=0; i<res.size(); ++i)
{
res[i] = arr.strides(i)/arr.itemsize();
......@@ -2404,16 +2416,16 @@ vector<int64_t> copy_strides(const py::array &arr)
return res;
}
vector<size_t> makeaxes(const py::array &in, py::object axes)
shape_t makeaxes(const py::array &in, py::object axes)
{
if (axes.is(py::none()))
{
vector<size_t> res(in.ndim());
shape_t res(in.ndim());
for (size_t i=0; i<res.size(); ++i)
res[i]=i;
return res;
}
auto tmp=axes.cast<vector<size_t>>();
auto tmp=axes.cast<shape_t>();
if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0))
throw runtime_error("bad axes argument");
for (size_t i=0; i<tmp.size(); ++i)
......@@ -2423,7 +2435,7 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
}
template<typename T> py::array xfftn_internal(const py::array &in,
const vector<size_t> &axes, double fct, bool fwd)
const shape_t &axes, double fct, bool fwd)
{
auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims);
......@@ -2435,11 +2447,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
// Type-dispatching entry point for the complex transform.
//   a    : input array; its dtype must be c128 or c64
//   axes : axes argument, converted with makeaxes(a, axes)
//   fct  : factor forwarded to xfftn_internal
//   fwd  : direction flag forwarded to xfftn_internal
// Calls xfftn_internal<double> when a's dtype is c128 and
// xfftn_internal<float> when it is c64. Any other dtype makes tcheck throw
// runtime_error("unsupported data type").
py::array xfftn(const py::array &a, py::object axes, double fct, bool fwd)
  {
  // Single dispatch through tcheck; the former if/else-if/throw chain ahead of
  // this statement made it unreachable and duplicated the same decision.
  return tcheck(a, c128, c64) ?
    xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd) :
    xfftn_internal<float> (a, makeaxes(a, axes), fct, fwd);
  }
// Convenience wrapper: calls xfftn with the direction flag set to true.
//   a    : input array
//   axes : axes argument, passed through unchanged
//   fct  : factor, passed through unchanged
// Returns whatever xfftn returns.
py::array fftn(const py::array &a, py::object axes, double fct)
  {
  const bool forward = true;
  return xfftn(a, axes, fct, forward);
  }
......@@ -2457,7 +2467,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
if (axes.size()==1) return res;
vector<size_t> axes2(axes.size()-1);
shape_t axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
......@@ -2466,11 +2476,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
}
// Type-dispatching entry point for the real-input transform.
//   in    : input array; its dtype must be f64 (float64) or f32 (float32)
//   axes_ : axes argument, passed through to rfftn_internal
//   fct   : factor, passed through to rfftn_internal
// Calls rfftn_internal<double> for f64 input and rfftn_internal<float> for
// f32 input. Any other dtype makes tcheck throw
// runtime_error("unsupported data type").
py::array rfftn(const py::array &in, py::object axes_, double fct)
  {
  // Single dispatch through tcheck; the former if/else-if/else-throw chain
  // ahead of this statement made it unreachable.
  return tcheck(in, f64, f32) ? rfftn_internal<double>(in, axes_, fct)
                              : rfftn_internal<float> (in, axes_, fct);
  }
template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct)
......@@ -2479,7 +2486,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py::array inter;
if (axes.size()>1) // need to do complex transforms
{
vector<size_t> axes2(axes.size()-1);
shape_t axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = xfftn_internal<T>(in, axes2, 1., false);
......@@ -2502,11 +2509,9 @@ template<typename T> py::array irfftn_internal(const py::array &in,
// Type-dispatching entry point for the complex-to-real transform.
//   in       : input array; its dtype must be c128 or c64
//   axes_    : axes argument, passed through to irfftn_internal
//   lastsize : passed through to irfftn_internal
//   fct      : factor; converted to T by irfftn_internal<T>'s signature
// Calls irfftn_internal<double> for c128 input and irfftn_internal<float>
// for c64 input. Any other dtype makes tcheck throw
// runtime_error("unsupported data type").
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
  double fct)
  {
  // Single dispatch through tcheck; the former if/else-if/throw chain ahead of
  // this statement made it unreachable.
  return tcheck(in, c128, c64) ?
    irfftn_internal<double>(in, axes_, lastsize, fct) :
    irfftn_internal<float> (in, axes_, lastsize, fct);
  }
template<typename T> py::array hartley_internal(const py::array &in,
......@@ -2522,11 +2527,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
}
// Type-dispatching entry point for hartley_internal.
//   in    : input array; its dtype must be f64 (float64) or f32 (float32)
//   axes_ : axes argument, passed through to hartley_internal
//   fct   : factor, passed through to hartley_internal
// Calls hartley_internal<double> for f64 input and hartley_internal<float>
// for f32 input. Any other dtype makes tcheck throw
// runtime_error("unsupported data type").
py::array hartley(const py::array &in, py::object axes_, double fct)
  {
  // Single dispatch through tcheck; the former if/else-if/throw chain ahead of
  // this statement made it unreachable.
  return tcheck(in, f64, f32) ? hartley_internal<double>(in, axes_, fct)
                              : hartley_internal<float> (in, axes_, fct);
  }
template<typename T>py::array complex2hartley(const py::array &in,
......@@ -2579,11 +2581,8 @@ template<typename T>py::array complex2hartley(const py::array &in,
// Type-dispatching entry point for complex2hartley.
//   in    : array whose dtype selects the precision; must be f64 or f32
//   tmp   : second array, passed through to complex2hartley (hartley2 passes
//           the result of rfftn here)
//   axes_ : axes argument, passed through to complex2hartley
// Calls complex2hartley<double> for f64 input and complex2hartley<float> for
// f32 input. Any other dtype of `in` makes tcheck throw
// runtime_error("unsupported data type"). The dtype of `tmp` is not checked here.
py::array mycomplex2hartley(const py::array &in,
  const py::array &tmp, py::object axes_)
  {
  // Single dispatch through tcheck; the former if/else-if/else-throw chain
  // ahead of this statement made it unreachable.
  return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_)
                              : complex2hartley<float> (in, tmp, axes_);
  }
// Two-step variant: first runs rfftn on the input, then hands the input and
// that result to mycomplex2hartley.
//   in    : input array (dtype checks are done by rfftn / mycomplex2hartley)
//   axes_ : axes argument, given to both steps
//   fct   : factor, applied in the rfftn step only
py::array hartley2(const py::array &in, py::object axes_, double fct)
  {
  const py::array spectrum = rfftn(in, axes_, fct);
  return mycomplex2hartley(in, spectrum, axes_);
  }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment