Commit 84f57346 authored by Peter Bell's avatar Peter Bell

Support python-style negative indices in axes arguments

parent 89d506fd
......@@ -63,13 +63,18 @@ shape_t makeaxes(const py::array &in, py::object axes)
res[i]=i;
return res;
}
auto tmp=axes.cast<shape_t>();
if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0))
auto tmp=axes.cast<std::vector<ptrdiff_t>>();
auto ndim = in.ndim();
if ((tmp.size()>size_t(ndim)) || (tmp.size()==0))
throw runtime_error("bad axes argument");
for (auto sz: tmp)
if (sz>=size_t(in.ndim()))
throw runtime_error("invalid axis number");
return tmp;
for (auto& sz: tmp)
{
if (sz<0)
sz += ndim;
if ((sz>=ndim) || (sz<0))
throw invalid_argument("axes exceeds dimensionality of output");
}
return shape_t(tmp.begin(), tmp.end());
}
#define DISPATCH(arr, T1, T2, T3, func, args) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment