Commit d6c1f716 authored by Martin Reinecke's avatar Martin Reinecke

polishing

parent 47cc2243
......@@ -2075,9 +2075,8 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize)
{
size_t fullsize=1;
size_t ndim = shape.size();
for (size_t i=0; i<ndim; ++i)
fullsize*=shape[i];
for (auto sz: shape)
fullsize*=sz;
auto othersize = fullsize/axsize;
auto tmpsize = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1);
return arr<char>(tmpsize*elemsize);
......@@ -2086,9 +2085,8 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
const shape_t &axes, size_t elemsize)
{
size_t fullsize=1;
size_t ndim = shape.size();
for (size_t i=0; i<ndim; ++i)
fullsize*=shape[i];
for (auto sz:shape)
fullsize*=sz;
size_t tmpsize=0;
for (size_t i=0; i<axes.size(); ++i)
{
......@@ -2127,44 +2125,46 @@ template<typename T> void pocketfft_general_c(const shape_t &shape,
for (size_t iax=0; iax<axes.size(); ++iax)
{
multi_iter it_in(a_in, axes[iax]), it_out(a_out, axes[iax]);
if ((!plan) || (it_in.length()!=plan->length()))
plan.reset(new pocketfft_c<T>(it_in.length()));
size_t len=it_in.length();
int64_t s_i=it_in.stride(), s_o=it_out.stride();
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len));
if (vlen>1)
while (it_in.remaining()>=vlen)
{
multioffset<vlen> p_i(it_in), p_o(it_out);
for (size_t i=0; i<it_in.length(); ++i)
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
{
tdatav[i].r[j] = data_in[p_i[j]+i*it_in.stride()].r;
tdatav[i].i[j] = data_in[p_i[j]+i*it_in.stride()].i;
tdatav[i].r[j] = data_in[p_i[j]+i*s_i].r;
tdatav[i].i[j] = data_in[p_i[j]+i*s_i].i;
}
forward ? plan->forward((vtype *)tdatav, fct)
: plan->backward((vtype *)tdatav, fct);
for (size_t i=0; i<it_out.length(); ++i)
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]+i*it_out.stride()].Set(tdatav[i].r[j],tdatav[i].i[j]);
data_out[p_o[j]+i*s_o].Set(tdatav[i].r[j],tdatav[i].i[j]);
}
while (it_in.remaining()>0)
{
if ((data_in==data_out) && it_in.stride()==1) // fully in-place
if ((data_in==data_out) && s_i==1) // fully in-place
forward ? plan->forward((T *)(data_in+it_in.offset()), fct)
: plan->backward((T *)(data_in+it_in.offset()), fct);
else if (it_out.stride()==1) // compute FFT in output location
else if (s_o==1) // compute FFT in output location
{
for (size_t i=0; i<it_out.length(); ++i)
data_out[it_out.offset()+i] = data_in[it_in.offset()+i*it_in.stride()];
for (size_t i=0; i<len; ++i)
data_out[it_out.offset()+i] = data_in[it_in.offset()+i*s_i];
forward ? plan->forward((T *)(data_out+it_out.offset()), fct)
: plan->backward((T *)(data_out+it_out.offset()), fct);
}
else
{
for (size_t i=0; i<it_in.length(); ++i)
tdata[i] = data_in[it_in.offset() + i*it_in.stride()];
for (size_t i=0; i<len; ++i)
tdata[i] = data_in[it_in.offset() + i*s_i];
forward ? plan->forward((T *)tdata, fct)
: plan->backward((T *)tdata, fct);
for (size_t i=0; i<it_out.length(); ++i)
data_out[it_out.offset()+i*it_out.stride()] = tdata[i];
for (size_t i=0; i<len; ++i)
data_out[it_out.offset()+i*s_o] = tdata[i];
}
it_in.advance();
it_out.advance();
......@@ -2193,44 +2193,46 @@ template<typename T> void pocketfft_general_hartley(const shape_t &shape,
for (size_t iax=0; iax<axes.size(); ++iax)
{
multi_iter it_in(a_in, axes[iax]), it_out(a_out, axes[iax]);
if ((!plan) || (it_in.length()!=plan->length()))
plan.reset(new pocketfft_r<T>(it_in.length()));
size_t len=it_in.length();
int64_t s_i=it_in.stride(), s_o=it_out.stride();
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
if (vlen>1)
while (it_in.remaining()>=vlen)
{
multioffset<vlen> p_i(it_in), p_o(it_out);
for (size_t i=0; i<it_in.length(); ++i)
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = data_in[p_i[j]+i*it_in.stride()];
tdatav[i][j] = data_in[p_i[j]+i*s_i];
plan->forward((vtype *)tdatav, fct);
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]] = tdatav[0][j];
size_t i=1, i1=1, i2=it_out.length()-1;
for (i=1; i<it_out.length()-1; i+=2, ++i1, --i2)
size_t i=1, i1=1, i2=len-1;
for (i=1; i<len-1; i+=2, ++i1, --i2)
for (size_t j=0; j<vlen; ++j)
{
data_out[p_o[j]+i1*it_out.stride()] = tdatav[i][j]+tdatav[i+1][j];
data_out[p_o[j]+i2*it_out.stride()] = tdatav[i][j]-tdatav[i+1][j];
data_out[p_o[j]+i1*s_o] = tdatav[i][j]+tdatav[i+1][j];
data_out[p_o[j]+i2*s_o] = tdatav[i][j]-tdatav[i+1][j];
}
if (i<it_out.length())
if (i<len)
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]+i1*it_out.stride()] = tdatav[i][j];
data_out[p_o[j]+i1*s_o] = tdatav[i][j];
}
while (it_in.remaining()>0)
{
for (size_t i=0; i<it_in.length(); ++i)
tdata[i] = data_in[it_in.offset()+i*it_in.stride()];
for (size_t i=0; i<len; ++i)
tdata[i] = data_in[it_in.offset()+i*s_i];
plan->forward((T *)tdata, fct);
// Hartley order
data_out[it_out.offset()] = tdata[0];
size_t i=1, i1=1, i2=it_out.length()-1;
for (i=1; i<it_out.length()-1; i+=2, ++i1, --i2)
for (i=1; i<len-1; i+=2, ++i1, --i2)
{
data_out[it_out.offset()+i1*it_out.stride()] = tdata[i]+tdata[i+1];
data_out[it_out.offset()+i2*it_out.stride()] = tdata[i]-tdata[i+1];
data_out[it_out.offset()+i1*s_o] = tdata[i]+tdata[i+1];
data_out[it_out.offset()+i2*s_o] = tdata[i]-tdata[i+1];
}
if (i<it_out.length())
data_out[it_out.offset()+i1*it_out.stride()] = tdata[i];
data_out[it_out.offset()+i1*s_o] = tdata[i];
it_in.advance();
it_out.advance();
}
......@@ -2255,30 +2257,25 @@ template<typename T> void pocketfft_general_r2c(const shape_t &shape,
multiarr a_in(shape, stride_in), a_out(shape, stride_out);
pocketfft_r<T> plan(shape[axis]);
multi_iter it_in(a_in, axis), it_out(a_out, axis);
size_t len=shape[axis], s_i=it_in.stride(), s_o=it_out.stride();
size_t len=shape[axis];
int64_t s_i=it_in.stride(), s_o=it_out.stride();
if (vlen>1)
while (it_in.remaining()>=vlen)
{
multioffset<vlen> p_i(it_in), p_o(it_out);
for (size_t i=0; i<it_in.length(); ++i)
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = data_in[p_i[j]+i*it_in.stride()];
tdatav[i][j] = data_in[p_i[j]+i*s_i];
plan.forward((vtype *)tdatav, fct);
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]].Set(tdatav[0][j]);
size_t i;
for (i=1; i<len-1; i+=2)
{
size_t io = (i+1)/2;
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]+io*it_out.stride()].Set(tdatav[i][j], tdatav[i+1][j]);
}
data_out[p_o[j]+((i+1)/2)*s_o].Set(tdatav[i][j], tdatav[i+1][j]);
if (i<len)
{
size_t io = (i+1)/2;
for (size_t j=0; j<vlen; ++j)
data_out[p_o[j]+io*it_out.stride()].Set(tdatav[i][j]);
}
data_out[p_o[j]+((i+1)/2)*s_o].Set(tdatav[i][j]);
}
while (it_in.remaining()>0)
{
......@@ -2310,7 +2307,8 @@ template<typename T> void pocketfft_general_c2r(const shape_t &shape_out,
multiarr a_in(shape_out, stride_in), a_out(shape_out, stride_out);
pocketfft_r<T> plan(shape_out[axis]);
multi_iter it_in(a_in, axis), it_out(a_out, axis);
size_t len=shape_out[axis], s_i=it_in.stride(), s_o=it_out.stride();
size_t len=shape_out[axis];
int64_t s_i=it_in.stride(), s_o=it_out.stride();
if (vlen>1)
while (it_in.remaining()>=vlen)
{
......@@ -2319,20 +2317,14 @@ template<typename T> void pocketfft_general_c2r(const shape_t &shape_out,
tdatav[0][j]=data_in[p_i[j]].r;
size_t i;
for (i=1; i<len-1; i+=2)
{
size_t ii = (i+1)/2;
for (size_t j=0; j<vlen; ++j)
{
tdatav[i][j] = data_in[p_i[j]+ii*s_i].r;
tdatav[i+1][j] = data_in[p_i[j]+ii*s_i].i;
tdatav[i][j] = data_in[p_i[j]+((i+1)/2)*s_i].r;
tdatav[i+1][j] = data_in[p_i[j]+((i+1)/2)*s_i].i;
}
}
if (i<len)
{
size_t ii = (i+1)/2;
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = data_in[p_i[j]+ii*s_i].r;
}
tdatav[i][j] = data_in[p_i[j]+((i+1)/2)*s_i].r;
plan.backward(tdatav, fct);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
......@@ -2412,8 +2404,8 @@ shape_t makeaxes(const py::array &in, py::object axes)
auto tmp=axes.cast<shape_t>();
if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0))
throw runtime_error("bad axes argument");
for (size_t i=0; i<tmp.size(); ++i)
if (tmp[i]>=size_t(in.ndim()))
for (auto sz: tmp)
if (sz>=size_t(in.ndim()))
throw runtime_error("invalid axis number");
return tmp;
}
......@@ -2452,9 +2444,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
if (axes.size()==1) return res;
shape_t axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
shape_t axes2(axes.begin(), --axes.end());
pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
(const cmplx<T> *)res.data(), (cmplx<T> *)res.mutable_data(), 1.);
return res;
......@@ -2468,16 +2458,8 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct)
{
auto axes = makeaxes(in, axes_);
py::array inter;
if (axes.size()>1) // need to do complex transforms
{
shape_t axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = xfftn_internal<T>(in, axes2, 1., false, false);
}
else
inter = in;
py::array inter = (axes.size()==1) ? in :
xfftn_internal<T>(in, shape_t(axes.begin(), --axes.end()), 1., false, false);
size_t axis = axes.back();
if (lastsize==0) lastsize=2*inter.shape(axis)-1;
......@@ -2541,7 +2523,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
{
auto tofs = it_tmp.offset();
auto ofs = it_out.offset();
size_t rofs = 0;
int64_t rofs = 0;
for (size_t i=0; i<it_out.pos.size(); ++i)
{
if (!swp[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment