Commit bdd8ea12 authored by Martin Reinecke
Browse files

more cleanup

parent 367ace56
......@@ -248,7 +248,7 @@ class sincos_2pibyn
}
public:
sincos_2pibyn(size_t n, bool half)
NOINLINE sincos_2pibyn(size_t n, bool half)
: data(2*n)
{
sincos_2pibyn_half(n, data.data());
......@@ -1968,26 +1968,8 @@ template<typename T0> class pocketfft_r
using shape_t = vector<size_t>;
using stride_t = vector<int64_t>;
// Per-dimension record: `n` is filled from a shape entry and `s` from the
// matching stride entry (see multiarr's constructor, which pushes {n[i], s[i]}).
struct diminfo
{ size_t n; int64_t s; };
// Per-dimension record used by multi_iter_io: `n` is filled from the input
// shape, `s_i` from the input strides and `s_o` from the output strides.
struct diminfo_io
{ size_t n; int64_t s_i, s_o; };
// Pairs a shape with a list of strides, stored as one diminfo
// (n taken from the shape, s taken from the strides) per dimension.
class multiarr
{
private:
// One entry per dimension, in the order given to the constructor.
vector<diminfo> dim;
public:
// Builds the per-dimension table from shape `n` and strides `s`.
// NOTE(review): s[i] is read for every i < n.size() with no length check,
// so `s` must have at least n.size() entries.
multiarr (const shape_t &n, const stride_t &s)
{
dim.reserve(n.size());
for (size_t i=0; i<n.size(); ++i)
dim.push_back({n[i], s[i]});
}
// Number of dimensions stored.
size_t ndim() const { return dim.size(); }
// The `n` value of dimension i (index is not range-checked).
size_t size(size_t i) const { return dim[i].n; }
// The `s` value of dimension i (index is not range-checked).
int64_t stride(size_t i) const { return dim[i].s; }
};
class multi_iter_io
{
......@@ -2000,17 +1982,18 @@ class multi_iter_io
size_t rem;
public:
multi_iter_io(const multiarr &arr_i, const multiarr &arr_o, size_t idim)
: pos(arr_i.ndim()-1, 0), ofs_i(0), ofs_o(0), len_i(arr_i.size(idim)),
len_o(arr_o.size(idim)), str_i(arr_i.stride(idim)),
str_o(arr_o.stride(idim)), rem(1)
{
dim.reserve(arr_i.ndim()-1);
for (size_t i=0; i<arr_i.ndim(); ++i)
multi_iter_io(const shape_t &shape_i, const stride_t &stride_i,
const shape_t &shape_o, const stride_t &stride_o, size_t idim)
: pos(shape_i.size()-1, 0), ofs_i(0), ofs_o(0), len_i(shape_i[idim]),
len_o(shape_o[idim]), str_i(stride_i[idim]),
str_o(stride_o[idim]), rem(1)
{
dim.reserve(shape_i.size()-1);
for (size_t i=0; i<shape_i.size(); ++i)
if (i!=idim)
{
dim.push_back({arr_i.size(i), arr_i.stride(i), arr_o.stride(i)});
rem *= arr_i.size(i);
dim.push_back({shape_i[i], stride_i[i], stride_o[i]});
rem *= shape_i[i];
}
}
void advance()
......@@ -2126,12 +2109,12 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
auto tdata = (cmplx<T> *)storage.data();
constexpr int vlen = VTYPE<T>::vlen;
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
unique_ptr<pocketfft_c<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
multi_iter_io it(a_in, a_out, axes[iax]);
multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out,
axes[iax]);
size_t len=it.length_in();
int64_t s_i=it.stride_in(), s_o=it.stride_out();
if ((!plan) || (len!=plan->length()))
......@@ -2175,11 +2158,8 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
}
it.advance();
}
// after the first dimension, take data from output array
a_in = a_out;
data_in = data_out;
// factor has been applied, use 1 for remaining axes
fct = T(1);
data_in = data_out; // after the first dimension, take data from output array
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
......@@ -2193,12 +2173,12 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap
constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data();
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
unique_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
multi_iter_io it(a_in, a_out, axes[iax]);
multi_iter_io it(shape, iax==0 ? stride_in : stride_out, shape, stride_out,
axes[iax]);
size_t len=it.length_in();
int64_t s_i=it.stride_in(), s_o=it.stride_out();
if ((!plan) || (len!=plan->length()))
......@@ -2241,11 +2221,8 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap
data_out[it.offset_out()+i1*s_o] = tdata[i];
it.advance();
}
// after the first dimension, take data from output array
a_in = a_out;
data_in = data_out;
// factor has been applied, use 1 for remaining axes
fct = T(1);
data_in = data_out; // after the first dimension, take data from output array
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
......@@ -2259,9 +2236,8 @@ template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape,
constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data();
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
pocketfft_r<T> plan(shape[axis]);
multi_iter_io it(a_in, a_out, axis);
multi_iter_io it(shape, stride_in, shape, stride_out, axis);
size_t len=shape[axis];
int64_t s_i=it.stride_in(), s_o=it.stride_out();
if (vlen>1)
......@@ -2308,9 +2284,8 @@ template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_ou
constexpr int vlen = VTYPE<T>::vlen;
auto tdata = (T *)storage.data();
multiarr a_in(shape_out, stride_in), a_out(shape_out, stride_out);
pocketfft_r<T> plan(shape_out[axis]);
multi_iter_io it(a_in, a_out, axis);
multi_iter_io it(shape_out, stride_in, shape_out, stride_out, axis);
size_t len=shape_out[axis];
int64_t s_i=it.stride_in(), s_o=it.stride_out();
if (vlen>1)
......@@ -2509,9 +2484,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
auto axes = makeaxes(in, axes_);
size_t axis = axes.back();
multiarr a_tmp(dims_tmp, stride_tmp),
a_out(dims_out, stride_out);
multi_iter_io it(a_tmp, a_out, axis);
multi_iter_io it(dims_tmp, stride_tmp, dims_out, stride_out, axis);
const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
T *odata = (T *)out.mutable_data();
vector<bool> swp(ndim-1,false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment