Commit bdd8ea12 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more cleanup

parent 367ace56
...@@ -248,7 +248,7 @@ class sincos_2pibyn ...@@ -248,7 +248,7 @@ class sincos_2pibyn
} }
public: public:
sincos_2pibyn(size_t n, bool half) NOINLINE sincos_2pibyn(size_t n, bool half)
: data(2*n) : data(2*n)
{ {
sincos_2pibyn_half(n, data.data()); sincos_2pibyn_half(n, data.data());
...@@ -1968,26 +1968,8 @@ template<typename T0> class pocketfft_r ...@@ -1968,26 +1968,8 @@ template<typename T0> class pocketfft_r
using shape_t = vector<size_t>; using shape_t = vector<size_t>;
using stride_t = vector<int64_t>; using stride_t = vector<int64_t>;
struct diminfo
{ size_t n; int64_t s; };
struct diminfo_io struct diminfo_io
{ size_t n; int64_t s_i, s_o; }; { size_t n; int64_t s_i, s_o; };
class multiarr
{
private:
vector<diminfo> dim;
public:
multiarr (const shape_t &n, const stride_t &s)
{
dim.reserve(n.size());
for (size_t i=0; i<n.size(); ++i)
dim.push_back({n[i], s[i]});
}
size_t ndim() const { return dim.size(); }
size_t size(size_t i) const { return dim[i].n; }
int64_t stride(size_t i) const { return dim[i].s; }
};
class multi_iter_io class multi_iter_io
{ {
...@@ -2000,17 +1982,18 @@ class multi_iter_io ...@@ -2000,17 +1982,18 @@ class multi_iter_io
size_t rem; size_t rem;
public: public:
multi_iter_io(const multiarr &arr_i, const multiarr &arr_o, size_t idim) multi_iter_io(const shape_t &shape_i, const stride_t &stride_i,
: pos(arr_i.ndim()-1, 0), ofs_i(0), ofs_o(0), len_i(arr_i.size(idim)), const shape_t &shape_o, const stride_t &stride_o, size_t idim)
len_o(arr_o.size(idim)), str_i(arr_i.stride(idim)), : pos(shape_i.size()-1, 0), ofs_i(0), ofs_o(0), len_i(shape_i[idim]),
str_o(arr_o.stride(idim)), rem(1) len_o(shape_o[idim]), str_i(stride_i[idim]),
{ str_o(stride_o[idim]), rem(1)
dim.reserve(arr_i.ndim()-1); {
for (size_t i=0; i<arr_i.ndim(); ++i) dim.reserve(shape_i.size()-1);
for (size_t i=0; i<shape_i.size(); ++i)
if (i!=idim) if (i!=idim)
{ {
dim.push_back({arr_i.size(i), arr_i.stride(i), arr_o.stride(i)}); dim.push_back({shape_i[i], stride_i[i], stride_o[i]});
rem *= arr_i.size(i); rem *= shape_i[i];
} }
} }
void advance() void advance()
...@@ -2126,12 +2109,12 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape, ...@@ -2126,12 +2109,12 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
auto tdata = (cmplx<T> *)storage.data(); auto tdata = (cmplx<T> *)storage.data();
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
unique_ptr<pocketfft_c<T>> plan; unique_ptr<pocketfft_c<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
multi_iter_io it(a_in, a_out, axes[iax]); multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out,
axes[iax]);
size_t len=it.length_in(); size_t len=it.length_in();
int64_t s_i=it.stride_in(), s_o=it.stride_out(); int64_t s_i=it.stride_in(), s_o=it.stride_out();
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
...@@ -2175,11 +2158,8 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape, ...@@ -2175,11 +2158,8 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
} }
it.advance(); it.advance();
} }
// after the first dimension, take data from output array data_in = data_out; // after the first dimension, take data from output array
a_in = a_out; fct = T(1); // factor has been applied, use 1 for remaining axes
data_in = data_out;
// factor has been applied, use 1 for remaining axes
fct = T(1);
} }
} }
...@@ -2193,12 +2173,12 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap ...@@ -2193,12 +2173,12 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data(); auto tdata = (T *)storage.data();
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
unique_ptr<pocketfft_r<T>> plan; unique_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
multi_iter_io it(a_in, a_out, axes[iax]); multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out,
axes[iax]);
size_t len=it.length_in(); size_t len=it.length_in();
int64_t s_i=it.stride_in(), s_o=it.stride_out(); int64_t s_i=it.stride_in(), s_o=it.stride_out();
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
...@@ -2241,11 +2221,8 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap ...@@ -2241,11 +2221,8 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap
data_out[it.offset_out()+i1*s_o] = tdata[i]; data_out[it.offset_out()+i1*s_o] = tdata[i];
it.advance(); it.advance();
} }
// after the first dimension, take data from output array data_in = data_out; // after the first dimension, take data from output array
a_in = a_out; fct = T(1); // factor has been applied, use 1 for remaining axes
data_in = data_out;
// factor has been applied, use 1 for remaining axes
fct = T(1);
} }
} }
...@@ -2259,9 +2236,8 @@ template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape, ...@@ -2259,9 +2236,8 @@ template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape,
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data(); auto tdata = (T *)storage.data();
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
pocketfft_r<T> plan(shape[axis]); pocketfft_r<T> plan(shape[axis]);
multi_iter_io it(a_in, a_out, axis); multi_iter_io it(shape, stride_in, shape, stride_out, axis);
size_t len=shape[axis]; size_t len=shape[axis];
int64_t s_i=it.stride_in(), s_o=it.stride_out(); int64_t s_i=it.stride_in(), s_o=it.stride_out();
if (vlen>1) if (vlen>1)
...@@ -2308,9 +2284,8 @@ template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_ou ...@@ -2308,9 +2284,8 @@ template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_ou
constexpr int vlen = VTYPE<T>::vlen; constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data(); auto tdata = (T *)storage.data();
multiarr a_in(shape_out, stride_in), a_out(shape_out, stride_out);
pocketfft_r<T> plan(shape_out[axis]); pocketfft_r<T> plan(shape_out[axis]);
multi_iter_io it(a_in, a_out, axis); multi_iter_io it(shape_out, stride_in, shape_out, stride_out, axis);
size_t len=shape_out[axis]; size_t len=shape_out[axis];
int64_t s_i=it.stride_in(), s_o=it.stride_out(); int64_t s_i=it.stride_in(), s_o=it.stride_out();
if (vlen>1) if (vlen>1)
...@@ -2509,9 +2484,7 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -2509,9 +2484,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out)); auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_);
size_t axis = axes.back(); size_t axis = axes.back();
multiarr a_tmp(dims_tmp, stride_tmp), multi_iter_io it(dims_tmp, stride_tmp, dims_out, stride_out, axis);
a_out(dims_out, stride_out);
multi_iter_io it(a_tmp, a_out, axis);
const cmplx<T> *tdata = (const cmplx<T> *)tmp.data(); const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
T *odata = (T *)out.mutable_data(); T *odata = (T *)out.mutable_data();
vector<bool> swp(ndim-1,false); vector<bool> swp(ndim-1,false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment