Commit 415305d7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

better interfaces

parent 007a96f6
......@@ -1968,50 +1968,76 @@ template<typename T0> class pocketfft_r
using shape_t = vector<size_t>;
using stride_t = vector<int64_t>;
size_t prod(const shape_t &shape)
{
size_t res=1;
for (auto sz: shape)
res*=sz;
return res;
}
template<typename T> class ndarr
{
private:
char *d;
const char *cd;
shape_t shp;
stride_t str;
public:
ndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
: d(nullptr), cd(reinterpret_cast<const char *>(data_)), shp(shape_),
str(stride_) {}
ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
: d(reinterpret_cast<char *>(data_)),
cd(reinterpret_cast<const char *>(data_)), shp(shape_), str(stride_) {}
size_t ndim() const { return shp.size(); }
size_t size() const { return prod(shp); }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const int64_t &stride(size_t i) const { return str[i]; }
T &operator[](int64_t ofs)
{ return *reinterpret_cast<T *>(d+ofs); }
const T &operator[](int64_t ofs) const
{ return *reinterpret_cast<const T *>(cd+ofs); }
};
template<size_t N, typename Ti, typename To> class multi_iter
{
public:
struct diminfo_io
{ size_t n; int64_t s_i, s_o; };
vector<diminfo_io> dim;
shape_t pos;
const ndarr<Ti> &iarr;
ndarr<To> &oarr;
private:
const Ti *p_ii, *p_i[N];
To *p_oi, *p_o[N];
size_t len_i, len_o;
int64_t str_i, str_o;
int64_t p_ii, p_i[N], str_i;
int64_t p_oi, p_o[N], str_o;
size_t idim;
size_t rem;
void advance_i()
{
for (int i=pos.size()-1; i>=0; --i)
{
if (i==int(idim)) continue;
++pos[i];
p_ii += dim[i].s_i;
p_oi += dim[i].s_o;
if (pos[i] < dim[i].n)
p_ii += iarr.stride(i);
p_oi += oarr.stride(i);
if (pos[i] < iarr.shape(i))
return;
pos[i] = 0;
p_ii -= dim[i].n*dim[i].s_i;
p_oi -= dim[i].n*dim[i].s_o;
p_ii -= iarr.shape(i)*iarr.stride(i);
p_oi -= oarr.shape(i)*oarr.stride(i);
}
}
public:
multi_iter(const Ti *d_i, const shape_t &shape_i, const stride_t &stride_i,
To *d_o, const shape_t &shape_o, const stride_t &stride_o, size_t idim)
: pos(shape_i.size()-1, 0), p_ii(d_i), p_oi(d_o), len_i(shape_i[idim]),
len_o(shape_o[idim]), str_i(stride_i[idim]),
str_o(stride_o[idim]), rem(1)
{
dim.reserve(shape_i.size()-1);
for (size_t i=0; i<shape_i.size(); ++i)
if (i!=idim)
{
dim.push_back({shape_i[i], stride_i[i], stride_o[i]});
rem *= shape_i[i];
}
}
multi_iter(const ndarr<Ti> &iarr_, ndarr<To> &oarr_, size_t idim_)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
idim(idim_), rem(iarr.size()/iarr.shape(idim))
{}
void advance(size_t n)
{
if (rem<n) throw runtime_error("underrun");
......@@ -2023,14 +2049,18 @@ template<size_t N, typename Ti, typename To> class multi_iter
}
rem -= n;
}
const Ti &in (size_t j, size_t i) const { return p_i[j][i*str_i]; }
To &out (size_t j, size_t i) const { return p_o[j][i*str_o]; }
size_t length_in() const { return len_i; }
size_t length_out() const { return len_o; }
const Ti &in (size_t i) const { return iarr[p_i[0] + i*str_i]; }
const Ti &in (size_t j, size_t i) const { return iarr[p_i[j] + i*str_i]; }
To &out (size_t i) { return oarr[p_o[0] + i*str_o]; }
To &out (size_t j, size_t i) { return oarr[p_o[j] + i*str_o]; }
size_t length_in() const { return iarr.shape(idim); }
size_t length_out() const { return oarr.shape(idim); }
int64_t stride_in() const { return str_i; }
int64_t stride_out() const { return str_o; }
size_t remaining() const { return rem; }
bool inplace() const { return p_ii==p_oi; }
bool inplace() const { return &iarr[0]==&oarr[0]; }
bool contiguous_in() const { return str_i==sizeof(Ti); }
bool contiguous_out() const { return str_o==sizeof(To); }
};
......@@ -2071,13 +2101,6 @@ template<> struct VTYPE<float>
};
#endif
size_t prod(const shape_t &shape)
{
size_t res=1;
for (auto sz: shape)
res*=sz;
return res;
}
template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize)
{
......@@ -2099,19 +2122,17 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
return arr<char>(tmpsize*elemsize);
}
template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out,
const shape_t &axes, bool forward, const cmplx<T> *data_in,
cmplx<T> *data_out, T fct)
template<typename T> NOINLINE void pocketfft_general_c(
const ndarr<cmplx<T>> &in, ndarr<cmplx<T>> &out,
const shape_t &axes, bool forward, T fct)
{
auto storage = alloc_tmp<T>(shape, axes, sizeof(cmplx<T>));
auto storage = alloc_tmp<T>(in.shape(), axes, sizeof(cmplx<T>));
unique_ptr<pocketfft_c<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr int vlen = VTYPE<T>::vlen;
multi_iter<vlen, cmplx<T>, cmplx<T>> it(iax==0? data_in : data_out, shape,
iax==0 ? stride_in : stride_out, data_out, shape, stride_out, axes[iax]);
multi_iter<vlen, cmplx<T>, cmplx<T>> it(iax==0? in : out, out, axes[iax]);
size_t len=it.length_in();
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len));
......@@ -2137,42 +2158,40 @@ template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
{
it.advance(1);
auto tdata = (cmplx<T> *)storage.data();
if (it.inplace() && it.stride_in()==1) // fully in-place
forward ? plan->forward((T *)(&it.in(0,0)), fct)
: plan->backward((T *)(&it.in(0,0)), fct);
else if (it.stride_out()==1) // compute FFT in output location
if (it.inplace() && it.contiguous_out()) // fully in-place
forward ? plan->forward((T *)(&it.in(0)), fct)
: plan->backward((T *)(&it.in(0)), fct);
else if (it.contiguous_out()) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
it.out(0,i) = it.in(0,i);
forward ? plan->forward((T *)(&it.out(0,0)), fct)
: plan->backward((T *)(&it.out(0,0)), fct);
it.out(i) = it.in(i);
forward ? plan->forward((T *)(&it.out(0)), fct)
: plan->backward((T *)(&it.out(0)), fct);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = it.in(0,i);
tdata[i] = it.in(i);
forward ? plan->forward((T *)tdata, fct)
: plan->backward((T *)tdata, fct);
for (size_t i=0; i<len; ++i)
it.out(0,i) = tdata[i];
it.out(i) = tdata[i];
}
}
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out,
const shape_t &axes, const T *data_in, T *data_out, T fct)
template<typename T> NOINLINE void pocketfft_general_hartley(
const ndarr<T> &in, ndarr<T> &out, const shape_t &axes, T fct)
{
auto storage = alloc_tmp<T>(shape, axes, sizeof(T));
auto storage = alloc_tmp<T>(in.shape(), axes, sizeof(T));
unique_ptr<pocketfft_r<T>> plan;
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr int vlen = VTYPE<T>::vlen;
multi_iter<vlen, T, T> it(iax==0 ? data_in : data_out, shape,
iax==0 ? stride_in : stride_out, data_out, shape, stride_out, axes[iax]);
multi_iter<vlen, T, T> it(iax==0 ? in : out, out, axes[iax]);
size_t len=it.length_in();
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
......@@ -2204,34 +2223,32 @@ template<typename T> NOINLINE void pocketfft_general_hartley(const shape_t &shap
it.advance(1);
auto tdata = (T *)storage.data();
for (size_t i=0; i<len; ++i)
tdata[i] = it.in(0,i);
tdata[i] = it.in(i);
plan->forward((T *)tdata, fct);
// Hartley order
it.out(0,0) = tdata[0];
it.out(0) = tdata[0];
size_t i=1, i1=1, i2=len-1;
for (i=1; i<len-1; i+=2, ++i1, --i2)
{
it.out(0,i1) = tdata[i]+tdata[i+1];
it.out(0,i2) = tdata[i]-tdata[i+1];
it.out(i1) = tdata[i]+tdata[i+1];
it.out(i2) = tdata[i]-tdata[i+1];
}
if (i<len)
it.out(0,i1) = tdata[i];
it.out(i1) = tdata[i];
}
fct = T(1); // factor has been applied, use 1 for remaining axes
}
}
template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const T *data_in, cmplx<T> *data_out, T fct)
template<typename T> NOINLINE void pocketfft_general_r2c(
const ndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, T fct)
{
auto storage = alloc_tmp<T>(shape, shape[axis], sizeof(T));
auto storage = alloc_tmp<T>(in.shape(), in.shape(axis), sizeof(T));
pocketfft_r<T> plan(shape[axis]);
pocketfft_r<T> plan(in.shape(axis));
constexpr int vlen = VTYPE<T>::vlen;
multi_iter<vlen, T, cmplx<T>>
it(data_in, shape, stride_in, data_out, shape, stride_out, axis);
size_t len=shape[axis];
multi_iter<vlen, T, cmplx<T>> it(in, out, axis);
size_t len=in.shape(axis);
if (vlen>1)
while (it.remaining()>=vlen)
{
......@@ -2257,27 +2274,25 @@ template<typename T> NOINLINE void pocketfft_general_r2c(const shape_t &shape,
it.advance(1);
auto tdata = (T *)storage.data();
for (size_t i=0; i<len; ++i)
tdata[i] = it.in(0,i);
tdata[i] = it.in(i);
plan.forward(tdata, fct);
it.out(0,0).Set(tdata[0]);
it.out(0).Set(tdata[0]);
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
it.out(0,ii).Set(tdata[i], tdata[i+1]);
it.out(ii).Set(tdata[i], tdata[i+1]);
if (i<len)
it.out(0,ii).Set(tdata[i]);
it.out(ii).Set(tdata[i]);
}
}
template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const cmplx<T> *data_in, T *data_out, T fct)
template<typename T> NOINLINE void pocketfft_general_c2r(
const ndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, T fct)
{
auto storage = alloc_tmp<T>(shape_out, shape_out[axis], sizeof(T));
pocketfft_r<T> plan(shape_out[axis]);
auto storage = alloc_tmp<T>(out.shape(), out.shape(axis), sizeof(T));
pocketfft_r<T> plan(out.shape(axis));
constexpr int vlen = VTYPE<T>::vlen;
multi_iter<vlen, cmplx<T>, T>
it(data_in, shape_out, stride_in, data_out, shape_out, stride_out, axis);
size_t len=shape_out[axis];
multi_iter<vlen, cmplx<T>, T> it(in, out, axis);
size_t len=out.shape(axis);
if (vlen>1)
while (it.remaining()>=vlen)
{
......@@ -2305,18 +2320,18 @@ template<typename T> NOINLINE void pocketfft_general_c2r(const shape_t &shape_ou
{
it.advance(1);
auto tdata = (T *)storage.data();
tdata[0]=it.in(0,0).r;
tdata[0]=it.in(0).r;
size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii)
{
tdata[i] = it.in(0,ii).r;
tdata[i+1] = it.in(0,ii).i;
tdata[i] = it.in(ii).r;
tdata[i+1] = it.in(ii).i;
}
if (i<len)
tdata[i] = it.in(0,ii).r;
tdata[i] = it.in(ii).r;
plan.backward(tdata, fct);
for (size_t i=0; i<len; ++i)
it.out(0,i) = tdata[i];
it.out(i) = tdata[i];
}
}
......@@ -2352,11 +2367,7 @@ stride_t copy_strides(const py::array &arr)
{
stride_t res(arr.ndim());
for (size_t i=0; i<res.size(); ++i)
{
res[i] = arr.strides(i)/arr.itemsize();
if (res[i]*arr.itemsize() != arr.strides(i))
throw runtime_error("weird strides");
}
res[i] = arr.strides(i);
return res;
}
......@@ -2383,9 +2394,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<complex<T>>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_c<T>(dims, s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
ndarr<cmplx<T>> ain(in.data(), dims, copy_strides(in));
ndarr<cmplx<T>> aout(res.mutable_data(), dims, copy_strides(res));
pocketfft_general_c<T>(ain, aout, axes, fwd, fct);
return res;
}
......@@ -2408,13 +2419,12 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
py::array res = py::array_t<complex<T>>(dims_out);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
ndarr<T> ain(in.data(), dims_in, copy_strides(in));
ndarr<cmplx<T>> aout(res.mutable_data(), dims_out, copy_strides(res));
pocketfft_general_r2c<T>(ain, aout, axes.back(), fct);
if (axes.size()==1) return res;
shape_t axes2(axes.begin(), --axes.end());
pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
(const cmplx<T> *)res.data(), (cmplx<T> *)res.mutable_data(), 1.);
pocketfft_general_c<T>(aout, aout, axes2, true, 1.);
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
......@@ -2436,8 +2446,9 @@ template<typename T> py::array irfftn_internal(const py::array &in,
auto dims_out(copy_shape(inter));
dims_out[axis] = lastsize;
py::array res = py::array_t<T>(dims_out);
pocketfft_general_c2r<T>(dims_out, copy_strides(inter), copy_strides(res),
axis, (const cmplx<T> *)inter.data(), (T *)res.mutable_data(), fct);
ndarr<cmplx<T>> ain(inter.data(), copy_shape(inter), copy_strides(inter));
ndarr<T> aout(res.mutable_data(), dims_out, copy_strides(res));
pocketfft_general_c2r<T>(ain, aout, axis, fct);
return res;
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
......@@ -2453,8 +2464,9 @@ template<typename T> py::array hartley_internal(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims);
pocketfft_general_hartley<T>(dims, copy_strides(in), copy_strides(res),
makeaxes(in, axes_), (const T *)in.data(), (T *)res.mutable_data(), fct);
ndarr<T> ain(in.data(), copy_shape(in), copy_strides(in));
ndarr<T> aout(res.mutable_data(), copy_shape(res), copy_strides(res));
pocketfft_general_hartley<T>(ain, aout, makeaxes(in, axes_), fct);
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct,
......@@ -2466,55 +2478,55 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
}
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_)
const py::array &tmp, py::object axes_, bool inplace)
{
int ndim = in.ndim();
auto dims_out(copy_shape(in));
py::array out = py::array_t<T>(dims_out);
auto dims_tmp(copy_shape(tmp));
auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
py::array out = inplace ? in : py::array_t<T>(dims_out);
ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
ndarr<T> aout(out.mutable_data(), copy_shape(out), copy_strides(out));
auto axes = makeaxes(in, axes_);
size_t axis = axes.back();
const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
T *odata = (T *)out.mutable_data();
multi_iter<1,cmplx<T>,T> it(tdata, dims_tmp, stride_tmp, odata, dims_out, stride_out, axis);
vector<bool> swp(ndim-1,false);
multi_iter<1,cmplx<T>,T> it(atmp, aout, axis);
vector<bool> swp(ndim,false);
for (auto i: axes)
if (i!=axis)
swp[i<axis ? i : i-1] = true;
swp[i] = true;
while(it.remaining()>0)
{
int64_t rofs = 0;
for (size_t i=0; i<it.pos.size(); ++i)
{
if (i==axis) continue;
if (!swp[i])
rofs += it.pos[i]*it.dim[i].s_o;
rofs += it.pos[i]*it.oarr.stride(i);
else
{
auto x = (it.pos[i]==0) ? 0 : it.dim[i].n-it.pos[i];
rofs += x*it.dim[i].s_o;
auto x = (it.pos[i]==0) ? 0 : it.iarr.shape(i)-it.pos[i];
rofs += x*it.oarr.stride(i);
}
}
it.advance(1);
for (size_t i=0; i<it.length_in(); ++i)
{
auto re = it.in(0,i).r;
auto im = it.in(0,i).i;
auto re = it.in(i).r;
auto im = it.in(i).i;
auto rev_i = (i==0) ? 0 : it.length_out()-i;
it.out(0,i) = re+im;
odata[rofs + rev_i*it.stride_out()] = re-im;
it.out(i) = re+im;
aout[rofs + rev_i*it.stride_out()] = re-im;
}
}
return out;
}
py::array mycomplex2hartley(const py::array &in,
const py::array &tmp, py::object axes_)
const py::array &tmp, py::object axes_, bool inplace)
{
return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_)
: complex2hartley<float> (in, tmp, axes_);
return tcheck(in, f64, f32) ? complex2hartley<double>(in, tmp, axes_, inplace)
: complex2hartley<float> (in, tmp, axes_, inplace);
}
py::array hartley2(const py::array &in, py::object axes_, double fct)
{ return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_); }
py::array hartley2(const py::array &in, py::object axes_, double fct,
bool inplace)
{ return mycomplex2hartley(in, rfftn(in, axes_, fct), axes_, inplace); }
const char *pypocketfft_DS = R"DELIM(Fast Fourier and Hartley transforms.
......@@ -2653,6 +2665,8 @@ PYBIND11_MODULE(pypocketfft, m)
"fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"inplace"_a=false);
m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("complex2hartley",&mycomplex2hartley, "in"_a, "tmp"_a, "axes"_a);
m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"inplace"_a=false);
m.def("complex2hartley",&mycomplex2hartley, "in"_a, "tmp"_a, "axes"_a,
"inplace"_a=false);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment