Commit 007a96f6 authored by Martin Reinecke

simplify

parent a287f23e
......@@ -1968,63 +1968,14 @@ template<typename T0> class pocketfft_r
using shape_t = vector<size_t>;
using stride_t = vector<int64_t>;
// Per-axis bookkeeping record: the extent of one axis together with the
// stride of that axis in the input array and in the output array.
struct diminfo_io
{
size_t n;    // extent of the axis
int64_t s_i; // stride of the axis in the input array
int64_t s_o; // stride of the axis in the output array
};
// Walks over all 1-D lines of an input/output array pair that run along axis
// `idim`. For the current line it exposes the input and output offsets, and
// it also exposes the length and stride of axis `idim` in both arrays.
class multi_iter_io
// NOTE(review): the next line is identical to the header of the multi_iter
// class template that follows this class and does not fit here syntactically;
// it looks like a stray line belonging to that class — confirm against the
// actual file.
template<size_t N, typename Ti, typename To> class multi_iter
{
public:
// Per-axis record: extent n, input stride s_i, output stride s_o.
struct diminfo_io
{ size_t n; int64_t s_i, s_o; };
// One entry for every axis except idim, in increasing axis order.
vector<diminfo_io> dim;
// Current index along each of the axes stored in dim.
shape_t pos;
// Offsets of the current line in the input and output array; both start at 0.
int64_t ofs_i, ofs_o;
// shape_i[idim] and shape_o[idim].
size_t len_i, len_o;
// stride_i[idim] and stride_o[idim].
int64_t str_i, str_o;
// Countdown of lines; starts as the product of shape_i over all axes except
// idim and is decremented by every advance() call.
size_t rem;
public:
// Stores length and stride of axis idim for both arrays and collects all
// other axes into dim. The extents of those other axes are taken from
// shape_i only; shape_o is consulted solely for axis idim.
// NOTE(review): shape_i.size()-1 is unsigned arithmetic and would wrap for an
// empty shape_i — presumably callers always pass at least one axis; confirm.
multi_iter_io(const shape_t &shape_i, const stride_t &stride_i,
const shape_t &shape_o, const stride_t &stride_o, size_t idim)
: pos(shape_i.size()-1, 0), ofs_i(0), ofs_o(0), len_i(shape_i[idim]),
len_o(shape_o[idim]), str_i(stride_i[idim]),
str_o(stride_o[idim]), rem(1)
{
dim.reserve(shape_i.size()-1);
for (size_t i=0; i<shape_i.size(); ++i)
if (i!=idim)
{
dim.push_back({shape_i[i], stride_i[i], stride_o[i]});
rem *= shape_i[i];
}
}
// Moves on to the next line. rem is decremented first; when it reaches zero
// the function returns without touching pos or the offsets.
// NOTE(review): rem is unsigned, so calling this with rem already 0 wraps it.
void advance()
{
if (--rem==0) return;
// Increment the multi-index starting at the last stored axis, carrying into
// earlier axes whenever an axis runs past its extent.
for (int i=pos.size()-1; i>=0; --i)
{
++pos[i];
ofs_i += dim[i].s_i;
ofs_o += dim[i].s_o;
if (pos[i] < dim[i].n)
return;
// Axis i overflowed: reset its index, take back the n strides accumulated
// along it, and continue with the next earlier axis.
pos[i] = 0;
ofs_i -= dim[i].n*dim[i].s_i;
ofs_o -= dim[i].n*dim[i].s_o;
}
}
// Read-only accessors for the current offsets, the length and stride of axis
// idim in each array, and the remaining-line counter.
int64_t offset_in() const { return ofs_i; }
int64_t offset_out() const { return ofs_o; }
size_t length_in() const { return len_i; }
size_t length_out() const { return len_o; }
int64_t stride_in() const { return str_i; }
int64_t stride_out() const { return str_o; }
size_t remaining() const { return rem; }
};
template<size_t N, typename Ti, typename To> class multi_iter
{
private:
vector<diminfo_io> dim;
shape_t pos;
const Ti *p_ii, *p_i[N];
To *p_oi, *p_o[N];
size_t len_i, len_o;
......@@ -2148,18 +2099,6 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
return arr<char>(tmpsize*elemsize);
}
// Snapshot of the input and output offsets of `vlen` successive positions of
// a multi_iter_io.
template<size_t vlen> struct multioffset_io
{
int64_t ofs_i[vlen];
int64_t ofs_o[vlen];

// For each of the vlen slots: record the iterator's current input and
// output offset, then call advance() on it. The iterator is modified.
multioffset_io(multi_iter_io &it)
{
size_t slot = 0;
while (slot < vlen)
{
ofs_i[slot] = it.offset_in();
ofs_o[slot] = it.offset_out();
it.advance();
++slot;
}
}

// Input offset recorded in slot i.
int64_t in(size_t i) const
{ return ofs_i[i]; }

// Output offset recorded in slot i.
int64_t out(size_t i) const
{ return ofs_o[i]; }
};
template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out,
const shape_t &axes, bool forward, const cmplx<T> *data_in,
......@@ -2536,17 +2475,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
auto axes = makeaxes(in, axes_);
size_t axis = axes.back();
multi_iter_io it(dims_tmp, stride_tmp, dims_out, stride_out, axis);
const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
T *odata = (T *)out.mutable_data();
multi_iter<1,cmplx<T>,T> it(tdata, dims_tmp, stride_tmp, odata, dims_out, stride_out, axis);
vector<bool> swp(ndim-1,false);
for (auto i: axes)
if (i!=axis)
swp[i<axis ? i : i-1] = true;
while(it.remaining()>0)
{
auto tofs = it.offset_in();
auto ofs = it.offset_out();
int64_t rofs = 0;
for (size_t i=0; i<it.pos.size(); ++i)
{
......@@ -2558,15 +2495,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
rofs += x*it.dim[i].s_o;
}
}
it.advance(1);
for (size_t i=0; i<it.length_in(); ++i)
{
auto re = tdata[tofs + i*it.stride_in()].r;
auto im = tdata[tofs + i*it.stride_in()].i;
auto re = it.in(0,i).r;
auto im = it.in(0,i).i;
auto rev_i = (i==0) ? 0 : it.length_out()-i;
odata[ofs + i*it.stride_out()] = re+im;
it.out(0,i) = re+im;
odata[rofs + rev_i*it.stride_out()] = re-im;
}
it.advance();
}
return out;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment