Commit 97730839 authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent d2dd06b3
......@@ -2359,6 +2359,17 @@ vector<size_t> copy_shape(const py::array &arr)
res[i] = arr.shape(i);
return res;
}
vector<int64_t> copy_strides(const py::array &arr)
{
vector<int64_t> res(arr.ndim());
for (size_t i=0; i<res.size(); ++i)
{
res[i] = arr.strides(i)/arr.itemsize();
if (res[i]*arr.itemsize() != arr.strides(i))
throw runtime_error("weird strides");
}
return res;
}
vector<size_t> makeaxes(const py::array &in, py::object axes)
{
......@@ -2377,45 +2388,33 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
throw runtime_error("invalid axis number");
return tmp;
}
void make_strides(const py::array &in, const py::array &out,
vector<int64_t> &stride_in, vector<int64_t> &stride_out)
{
for (size_t i=0; i<stride_in.size(); ++i)
{
stride_in[i] = in.strides(i)/in.itemsize();
if (stride_in[i]*in.itemsize() != in.strides(i))
throw runtime_error("weird strides");
stride_out[i] = out.strides(i)/out.itemsize();
if (stride_out[i]*out.itemsize() != out.strides(i))
throw runtime_error("weird strides");
}
}
py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd)
template<typename T> py::array execute(const py::array &in, const vector<size_t> &axes, double fct, bool fwd)
{
py::array res;
auto dims(copy_shape(in));
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(c64))
res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type");
make_strides(in, res, s_i, s_o);
if (in.dtype().is(c128))
pocketfft_general_c<double>(dims,
s_i, s_o, axes, fwd, (const cmplx<double> *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
else
pocketfft_general_c<float>(dims,
s_i, s_o, axes, fwd, (const cmplx<float> *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
py::array res = py::array_t<complex<T>>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_c<T>(dims,
s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
return res;
}
py::array fftn(const py::array &a, py::object axes, double fct)
{ return execute(a, makeaxes(a, axes), fct, true); }
{
if (a.dtype().is(c128))
return execute<double>(a, makeaxes(a, axes), fct, true);
else if (a.dtype().is(c64))
return execute<float>(a, makeaxes(a, axes), fct, true);
throw runtime_error("unsupported data type");
}
py::array ifftn(const py::array &a, py::object axes, double fct)
{ return execute(a, makeaxes(a, axes), fct, false); }
{
if (a.dtype().is(c128))
return execute<double>(a, makeaxes(a, axes), fct, false);
else if (a.dtype().is(c64))
return execute<float>(a, makeaxes(a, axes), fct, false);
throw runtime_error("unsupported data type");
}
template<typename T> py::array rfftn_internal(const py::array &in,
py::object axes_, T fct)
......@@ -2423,9 +2422,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto axes = makeaxes(in, axes_);
auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
py::array res = py::array_t<complex<T>>(dims_out);
make_strides(in, res, s_i, s_o);
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
if (axes.size()==1) return res;
......@@ -2433,8 +2431,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
(const cmplx<T> *)res.data(),
(cmplx<T> *)res.mutable_data(), 1.);
(const cmplx<T> *)res.data(), (cmplx<T> *)res.mutable_data(), 1.);
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
......@@ -2445,8 +2442,8 @@ py::array rfftn(const py::array &in, py::object axes_, double fct)
return rfftn_internal<float>(in, axes_, fct);
else throw runtime_error("unsupported data type");
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct)
template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, T fct)
{
auto axes = makeaxes(in, axes_);
py::array inter;
......@@ -2455,7 +2452,7 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = execute(in, axes2, 1., false);
inter = execute<T>(in, axes2, 1., false);
}
else
inter = in;
......@@ -2465,27 +2462,22 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
if (int64_t(lastsize/2) + 1 != inter.shape(axis))
throw runtime_error("bad lastsize");
auto dims_in(copy_shape(inter)), dims_out(dims_in);
vector<int64_t> s_i(inter.ndim()), s_o(inter.ndim());
dims_out[axis] = lastsize;
py::array res;
if (inter.dtype().is(c128))
res = py::array_t<double>(dims_out);
else if (inter.dtype().is(c64))
res = py::array_t<float>(dims_out);
else throw runtime_error("unsupported data type");
make_strides(inter, res, s_i, s_o);
if (inter.dtype().is(c128))
pocketfft_general_c2r<double>(dims_out,
s_i, s_o,
axis, (const cmplx<double> *)inter.data(),
(double *)res.mutable_data(), fct);
else
pocketfft_general_c2r<float>(dims_out,
s_i, s_o,
axis, (const cmplx<float> *)inter.data(),
(float *)res.mutable_data(), fct);
py::array res = py::array_t<T>(dims_out);
auto s_i(copy_strides(inter)), s_o(copy_strides(res));
pocketfft_general_c2r<T>(dims_out, s_i, s_o, axis,
(const cmplx<T> *)inter.data(), (T *)res.mutable_data(), fct);
return res;
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct)
{
if (in.dtype().is(c128))
return irfftn_internal<double>(in, axes_, lastsize, fct);
else if (in.dtype().is(c64))
return irfftn_internal<float>(in, axes_, lastsize, fct);
throw runtime_error("unsupported data type");
}
template<typename T> py::array hartley_internal(const py::array &in,
py::object axes_, double fct)
......@@ -2493,10 +2485,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in));
py::array res = py::array_t<T>(dims);
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
make_strides(in, res, s_i, s_o);
pocketfft_general_hartley<T>(dims,
s_i, s_o, axes, (const T *)in.data(),
auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_hartley<T>(dims, s_i, s_o, axes, (const T *)in.data(),
(T *)res.mutable_data(), 1.);
return res;
}
......@@ -2516,8 +2506,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto dims_out(copy_shape(in));
py::array out = py::array_t<T>(dims_out);
auto dims_tmp(copy_shape(tmp));
vector<int64_t> stride_tmp(ndim), stride_out(ndim);
make_strides(tmp, out, stride_tmp, stride_out);
auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
auto axes = makeaxes(in, axes_);
size_t axis = axes.back();
multiarr a_tmp(dims_tmp, stride_tmp),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment