Commit 287c1d70 authored by Martin Reinecke
Browse files

factor out common code

parent 052f8584
......@@ -1997,7 +1997,15 @@ class multi_iter
};
#if defined(__AVX__)
#if (defined(__AVX512F__))
#include <x86intrin.h>
#define HAVE_VECSUPPORT
// Trait that maps a scalar type T to the SIMD vector type used for it in
// this (AVX-512) configuration. The primary template has no members, so
// VTYPE<T>::type is available only for the specializations listed below.
template<typename T> struct VTYPE {};
template<> struct VTYPE<float> { using type = __m512; };
template<> struct VTYPE<double> { using type = __m512d; };
#elif (defined(__AVX__))
#include <x86intrin.h>
#define HAVE_VECSUPPORT
template<typename T> struct VTYPE{};
......@@ -2005,7 +2013,7 @@ template<> struct VTYPE<double>
{ using type = __m256d; };
template<> struct VTYPE<float>
{ using type = __m256; };
#elif defined(__SSE2__)
#elif (defined(__SSE2__))
#include <x86intrin.h>
#define HAVE_VECSUPPORT
template<typename T> struct VTYPE{};
......@@ -2358,6 +2366,19 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
throw runtime_error("invalid axis number");
return tmp;
}
/// Converts the byte strides of two pybind11 arrays into element strides.
///
/// For every axis i in [0, stride_in.size()), stores
/// in.strides(i)/in.itemsize() in stride_in[i] and
/// out.strides(i)/out.itemsize() in stride_out[i].
///
/// @param in         array whose strides are written to stride_in
/// @param out        array whose strides are written to stride_out
/// @param stride_in  pre-sized result vector; its length determines how many
///                   axes are processed
/// @param stride_out pre-sized result vector; must be as long as stride_in
/// @throws runtime_error if the two result vectors differ in length, if an
///         item size is not positive, or if a byte stride is not a whole
///         multiple of the corresponding item size ("weird strides")
void make_strides(const py::array &in, const py::array &out,
                  vector<int64_t> &stride_in, vector<int64_t> &stride_out)
{
  // Both vectors are indexed with the same counter below; refuse to write
  // past the end of the shorter one.
  if (stride_out.size() != stride_in.size())
    throw runtime_error("stride array size mismatch");

  // The item sizes do not change per axis, so read them once. They are used
  // as divisors, hence must be strictly positive.
  const auto isz_in = in.itemsize();
  const auto isz_out = out.itemsize();
  if ((isz_in <= 0) || (isz_out <= 0))
    throw runtime_error("invalid item size");

  for (size_t i=0; i<stride_in.size(); ++i)
  {
    const auto bytes_in = in.strides(i);
    if (bytes_in % isz_in != 0)
      throw runtime_error("weird strides");
    stride_in[i] = bytes_in/isz_in;

    const auto bytes_out = out.strides(i);
    if (bytes_out % isz_out != 0)
      throw runtime_error("weird strides");
    stride_out[i] = bytes_out/isz_out;
  }
}
py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd)
{
......@@ -2371,15 +2392,7 @@ py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd
else if (in.dtype().is(c64))
res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
}
make_strides(in, res, s_i, s_o);
if (in.dtype().is(c128))
pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
......@@ -2414,15 +2427,7 @@ py::array rfftn(const py::array &in, py::object axes_, double fct)
else if (in.dtype().is(f32))
res = py::array_t<complex<float>>(dims_out);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
}
make_strides(in, res, s_i, s_o);
if (in.dtype().is(f64))
{
pocketfft_general_r2c<double>(in.ndim(), dims_in.data(),
......@@ -2482,15 +2487,7 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
else if (inter.dtype().is(c64))
res = py::array_t<float>(dims_out);
else throw runtime_error("unsupported data type");
for (int i=0; i<inter.ndim(); ++i)
{
s_i[i] = inter.strides(i)/inter.itemsize();
if (s_i[i]*inter.itemsize() != inter.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
}
make_strides(inter, res, s_i, s_o);
if (inter.dtype().is(c128))
pocketfft_general_c2r<double>(inter.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
......@@ -2516,15 +2513,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct)
else if (in.dtype().is(f32))
res = py::array_t<float>(dims);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
s_i[i] = in.strides(i)/in.itemsize();
if (s_i[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
s_o[i] = res.strides(i)/res.itemsize();
if (s_o[i]*res.itemsize() != res.strides(i))
throw runtime_error("weird strides");
}
make_strides(in, res, s_i, s_o);
if (in.dtype().is(f64))
pocketfft_general_hartley<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
......@@ -2554,17 +2543,10 @@ template<typename T>py::array complex2hartley(const py::array &in,
vector<size_t> dims_tmp(ndim);
vector<int64_t> stride_tmp(ndim), stride_out(ndim);
for (int i=0; i<ndim; ++i)
{
dims_tmp[i] = tmp.shape(i);
stride_tmp[i] = tmp.strides(i)/tmp.itemsize();
if (stride_tmp[i]*tmp.itemsize() != tmp.strides(i))
throw runtime_error("weird strides");
stride_out[i] = out.strides(i)/out.itemsize();
if (stride_out[i]*out.itemsize() != out.strides(i))
throw runtime_error("weird strides");
}
make_strides(tmp, out, stride_tmp, stride_out);
auto axes = makeaxes(in, axes_);
int axis = axes.back();
size_t axis = axes.back();
multiarr a_tmp(ndim, dims_tmp.data(), stride_tmp.data()),
a_out(ndim, dims_out.data(), stride_out.data());
multi_iter it_tmp(a_tmp, axis), it_out(a_out, axis);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment