Commit d2dd06b3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

compactify

parent fbee59ac
......@@ -2352,6 +2352,14 @@ auto c128 = py::dtype("complex128");
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
vector<size_t> copy_shape(const py::array &arr)
{
vector<size_t> res(arr.ndim());
for (size_t i=0; i<res.size(); ++i)
res[i] = arr.shape(i);
return res;
}
vector<size_t> makeaxes(const py::array &in, py::object axes)
{
if (axes.is(py::none()))
......@@ -2386,10 +2394,8 @@ void make_strides(const py::array &in, const py::array &out,
py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd)
{
py::array res;
vector<size_t> dims(in.ndim());
auto dims(copy_shape(in));
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(c64))
......@@ -2411,48 +2417,34 @@ py::array fftn(const py::array &a, py::object axes, double fct)
py::array ifftn(const py::array &a, py::object axes, double fct)
{ return execute(a, makeaxes(a, axes), fct, false); }
py::array rfftn(const py::array &in, py::object axes_, double fct)
template<typename T> py::array rfftn_internal(const py::array &in,
py::object axes_, T fct)
{
auto axes = makeaxes(in, axes_);
py::array res;
vector<size_t> dims_in(in.ndim()), dims_out(in.ndim());
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
{
dims_in[i] = in.shape(i);
dims_out[i] = in.shape(i);
}
auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
if (in.dtype().is(f64))
res = py::array_t<complex<double>>(dims_out);
else if (in.dtype().is(f32))
res = py::array_t<complex<float>>(dims_out);
else throw runtime_error("unsupported data type");
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
py::array res = py::array_t<complex<T>>(dims_out);
make_strides(in, res, s_i, s_o);
if (in.dtype().is(f64))
pocketfft_general_r2c<double>(dims_in,
s_i, s_o,
axes.back(), (const double *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
else
pocketfft_general_r2c<float>(dims_in,
s_i, s_o,
axes.back(), (const float *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
pocketfft_general_r2c<T>(dims_in, s_i, s_o, axes.back(), (const T *)in.data(),
(cmplx<T> *)res.mutable_data(), fct);
if (axes.size()==1) return res;
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
if (in.dtype().is(f64))
pocketfft_general_c<double>(dims_out,
s_o, s_o, axes2, true, (const cmplx<double> *)res.data(),
(cmplx<double> *)res.mutable_data(), 1.);
else
pocketfft_general_c<float>(dims_out,
s_o, s_o, axes2, true, (const cmplx<float> *)res.data(),
(cmplx<float> *)res.mutable_data(), 1.);
pocketfft_general_c<T>(dims_out, s_o, s_o, axes2, true,
(const cmplx<T> *)res.data(),
(cmplx<T> *)res.mutable_data(), 1.);
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
{
if (in.dtype().is(f64))
return rfftn_internal<double>(in, axes_, fct);
else if (in.dtype().is(f32))
return rfftn_internal<float>(in, axes_, fct);
else throw runtime_error("unsupported data type");
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
double fct)
{
......@@ -2472,13 +2464,8 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
if (lastsize==0) lastsize=2*inter.shape(axis)-1;
if (int64_t(lastsize/2) + 1 != inter.shape(axis))
throw runtime_error("bad lastsize");
vector<size_t> dims_in(inter.ndim()), dims_out(inter.ndim());
auto dims_in(copy_shape(inter)), dims_out(dims_in);
vector<int64_t> s_i(inter.ndim()), s_o(inter.ndim());
for (int i=0; i<inter.ndim(); ++i)
{
dims_in[i] = inter.shape(i);
dims_out[i] = inter.shape(i);
}
dims_out[axis] = lastsize;
py::array res;
if (inter.dtype().is(c128))
......@@ -2499,48 +2486,37 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
(float *)res.mutable_data(), fct);
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct)
template<typename T> py::array hartley_internal(const py::array &in,
py::object axes_, double fct)
{
auto axes = makeaxes(in, axes_);
py::array res;
vector<size_t> dims(in.ndim());
auto dims(copy_shape(in));
py::array res = py::array_t<T>(dims);
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for(int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
if (in.dtype().is(f64))
res = py::array_t<double>(dims);
else if (in.dtype().is(f32))
res = py::array_t<float>(dims);
else throw runtime_error("unsupported data type");
make_strides(in, res, s_i, s_o);
if (in.dtype().is(f64))
pocketfft_general_hartley<double>(dims,
s_i, s_o, axes, (const double *)in.data(),
(double *)res.mutable_data(), 1.);
else
pocketfft_general_hartley<float>(dims,
s_i, s_o, axes, (const float *)in.data(),
(float *)res.mutable_data(), 1.);
pocketfft_general_hartley<T>(dims,
s_i, s_o, axes, (const T *)in.data(),
(T *)res.mutable_data(), 1.);
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct)
{
if (in.dtype().is(f64))
return hartley_internal<double>(in, axes_, fct);
else if (in.dtype().is(f32))
return hartley_internal<float>(in, axes_, fct);
throw runtime_error("unsupported data type");
}
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_)
{
int ndim = in.ndim();
vector<size_t> dims_out(in.ndim());
for (int i=0; i<ndim; ++i)
dims_out[i] = in.shape(i);
py::array out;
if (in.dtype().is(f64))
out = py::array_t<double>(dims_out);
else if (in.dtype().is(f32))
out = py::array_t<float>(dims_out);
else throw runtime_error("unsupported data type");
vector<size_t> dims_tmp(ndim);
auto dims_out(copy_shape(in));
py::array out = py::array_t<T>(dims_out);
auto dims_tmp(copy_shape(tmp));
vector<int64_t> stride_tmp(ndim), stride_out(ndim);
for (int i=0; i<ndim; ++i)
dims_tmp[i] = tmp.shape(i);
make_strides(tmp, out, stride_tmp, stride_out);
auto axes = makeaxes(in, axes_);
size_t axis = axes.back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment