Commit 89a7dd70 authored by Martin Reinecke

synchronize

parent 39de2274
......@@ -377,6 +377,42 @@ struct util // hack to avoid duplicate symbols
res*=sz;
return res;
}
// Basic consistency check of an array description.
// Throws runtime_error if
//  - shape has no dimensions,
//  - either stride vector has a different length than shape,
//  - any extent in shape is smaller than 1,
//  - inplace is set and the input and output strides differ in any dimension.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace)
  {
  const auto ndim = shape.size();
  if (ndim<1)
    throw runtime_error("ndim must be >= 1");
  const bool strides_ok = (stride_in.size()==ndim) && (stride_out.size()==ndim);
  if (!strides_ok)
    throw runtime_error("stride dimension mismatch");
  for (size_t idx=0; idx<ndim; ++idx)
    if (shape[idx]<1)
      throw runtime_error("zero extent detected");
  // the stride comparison is only required for in-place operation
  if (!inplace) return;
  for (size_t idx=0; idx<ndim; ++idx)
    if (stride_in[idx]!=stride_out[idx])
      throw runtime_error("stride mismatch");
  }
// Consistency check for transforms over a list of axes.
// Runs the shape/stride check first, then throws runtime_error if an entry
// of axes is not smaller than the number of dimensions or if the same axis
// appears more than once.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace,
  const shape_t &axes)
  {
  sanity_check(shape, stride_in, stride_out, inplace);
  const auto ndim = shape.size();
  shape_t seen(ndim,0); // number of times each axis has been encountered
  for (auto it=axes.begin(); it!=axes.end(); ++it)
    {
    const auto ax = *it;
    if (!(ax<ndim))
      throw runtime_error("bad axis number");
    seen[ax] += 1;
    if (seen[ax]>1)
      throw runtime_error("axis specified repeatedly");
    }
  }
// Consistency check for transforms along a single axis.
// Runs the shape/stride check first, then throws runtime_error if axis is
// not smaller than the number of dimensions.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace,
  size_t axis)
  {
  sanity_check(shape, stride_in, stride_out, inplace);
  const auto ndim = shape.size();
  if (!(axis<ndim))
    throw runtime_error("bad axis number");
  }
};
// Element (a,b,c) of the array `ch`, stored linearly with `a` varying
// fastest: the stride of `b` is `ido` and the stride of `c` is `ido*l1`.
// `ch`, `ido` and `l1` are taken from the scope where the macro is expanded.
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
......@@ -2350,6 +2386,7 @@ template<typename T> void c2c(const shape_t &shape,
bool forward, const void *data_in, void *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
ndarr<cmplx<T>> ain(data_in, shape, stride_in),
aout(data_out, shape, stride_out);
general_c(ain, aout, axes, forward, fct);
......@@ -2360,16 +2397,34 @@ template<typename T> void r2c(const shape_t &shape,
const void *data_in, void *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
ndarr<T> ain(data_in, shape, stride_in);
ndarr<cmplx<T>> aout(data_out, shape, stride_out);
general_r2c(ain, aout, axis, fct);
}
template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, const void *data_in,
void *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
r2c(shape, stride_in, stride_out, axes.back(), data_in, data_out, fct);
if (axes.size()==1) return;
shape_t shape_out(shape);
shape_out[axes.back()] = shape[axes.back()]/2 + 1;
auto newaxes = shape_t{axes.begin(), --axes.end()};
c2c(shape_out, stride_out, stride_out, newaxes, true, data_out, data_out,
T(1));
}
template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const void *data_in, void *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
shape_t shape_out(shape);
shape_out[axis] = new_size;
ndarr<cmplx<T>> ain(data_in, shape, stride_in);
......@@ -2377,11 +2432,36 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
general_c2r(ain, aout, axis, fct);
}
template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const void *data_in, void *data_out, T fct)
{
using namespace detail;
if (axes.size()==1)
{
c2r(shape, new_size, stride_in, stride_out, axes[0],
data_in, data_out, fct);
return;
}
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
auto nval = util::prod(shape);
stride_t stride_inter(shape.size());
stride_inter.back() = sizeof(cmplx<T>);
for (int i=shape.size()-2; i>=0; --i)
stride_inter[i] = stride_inter[i+1]*shape[i+1];
arr<char> tmp(nval*sizeof(cmplx<T>));
auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), T(1));
c2r(shape, new_size, stride_inter, stride_out, axes.back(),
tmp.data(), data_out, fct);
}
// Real-to-real transform along a single axis.
// Validates the array description with util::sanity_check, wraps the input
// and output buffers (both described by shape) as ndarr<T> objects and
// passes them to general_r together with axis, forward and fct.
template<typename T> void r2r_fftpack(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, size_t axis,
  bool forward, const void *data_in, void *data_out, T fct)
  {
  using namespace detail;
  const bool inplace = (data_in==data_out);
  util::sanity_check(shape, stride_in, stride_out, inplace, axis);
  ndarr<T> ain(data_in, shape, stride_in);
  ndarr<T> aout(data_out, shape, stride_out);
  general_r(ain, aout, axis, forward, fct);
  }
......@@ -2391,6 +2471,7 @@ template<typename T> void r2r_hartley(const shape_t &shape,
const void *data_in, void *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
ndarr<T> ain(data_in, shape, stride_in), aout(data_out, shape, stride_out);
general_hartley(ain, aout, axes, fct);
}
......
......@@ -26,7 +26,6 @@ namespace {
using namespace std;
using namespace pocketfft;
using namespace pocketfft::detail;
namespace py = pybind11;
......@@ -83,9 +82,8 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<complex<T>>(dims);
ndarr<cmplx<T>> ain(in.data(), dims, copy_strides(in));
ndarr<cmplx<T>> aout(res.mutable_data(), dims, copy_strides(res));
general_c<T>(ain, aout, axes, fwd, fct);
c2c(dims, copy_strides(in), copy_strides(res), axes, fwd, in.data(),
res.mutable_data(), T(fct));
return res;
}
......@@ -108,12 +106,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
py::array res = py::array_t<complex<T>>(dims_out);
ndarr<T> ain(in.data(), dims_in, copy_strides(in));
ndarr<cmplx<T>> aout(res.mutable_data(), dims_out, copy_strides(res));
general_r2c<T>(ain, aout, axes.back(), fct);
if (axes.size()==1) return res;
shape_t axes2(axes.begin(), --axes.end());
general_c<T>(aout, aout, axes2, true, 1.);
r2c(dims_in, copy_strides(in), copy_strides(res), axes, in.data(),
res.mutable_data(), T(fct));
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
......@@ -126,9 +120,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims);
ndarr<T> ain(in.data(), dims, copy_strides(in));
ndarr<T> aout(res.mutable_data(), dims, copy_strides(res));
general_r<T>(ain, aout, axis, fwd, fct);
r2r_fftpack(dims, copy_strides(in), copy_strides(res), axis, fwd,
in.data(), res.mutable_data(), T(fct));
return res;
}
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
......@@ -148,19 +141,15 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct)
{
auto axes = makeaxes(in, axes_);
py::array inter = (axes.size()==1) ? in :
xfftn_internal<T>(in, shape_t(axes.begin(), --axes.end()), 1., false, false);
size_t axis = axes.back();
if (lastsize==0) lastsize=2*inter.shape(axis)-1;
if (ptrdiff_t(lastsize/2) + 1 != inter.shape(axis))
shape_t dims_in(copy_shape(in)), dims_out=dims_in;
if (lastsize==0) lastsize=2*dims_in[axis]-1;
if ((lastsize/2) + 1 != dims_in[axis])
throw runtime_error("bad lastsize");
auto dims_out(copy_shape(inter));
dims_out[axis] = lastsize;
py::array res = py::array_t<T>(dims_out);
ndarr<cmplx<T>> ain(inter.data(), copy_shape(inter), copy_strides(inter));
ndarr<T> aout(res.mutable_data(), dims_out, copy_strides(res));
general_c2r<T>(ain, aout, axis, fct);
c2r(dims_in, lastsize, copy_strides(in), copy_strides(res), axes,
in.data(), res.mutable_data(), T(fct));
return res;
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
......@@ -176,9 +165,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims);
ndarr<T> ain(in.data(), copy_shape(in), copy_strides(in));
ndarr<T> aout(res.mutable_data(), copy_shape(res), copy_strides(res));
general_hartley<T>(ain, aout, makeaxes(in, axes_), fct);
r2r_hartley(dims, copy_strides(in), copy_strides(res), makeaxes(in, axes_),
in.data(), res.mutable_data(), T(fct));
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct,
......@@ -192,6 +180,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace)
{
using namespace pocketfft::detail;
int ndim = in.ndim();
auto dims_out(copy_shape(in));
py::array out = inplace ? in : py::array_t<T>(dims_out);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment