Skip to content
Snippets Groups Projects

Support python-style negative indices in axes arguments

Merged Peter Bell requested to merge negative-axes into new_norm
1 file
+ 11
6
Compare changes
  • Side-by-side
  • Inline
+ 11
6
@@ -63,13 +63,18 @@ shape_t makeaxes(const py::array &in, py::object axes)
@@ -63,13 +63,18 @@ shape_t makeaxes(const py::array &in, py::object axes)
res[i]=i;
res[i]=i;
return res;
return res;
}
}
auto tmp=axes.cast<shape_t>();
auto tmp=axes.cast<std::vector<ptrdiff_t>>();
if ((tmp.size()>size_t(in.ndim())) || (tmp.size()==0))
auto ndim = in.ndim();
 
if ((tmp.size()>size_t(ndim)) || (tmp.size()==0))
throw runtime_error("bad axes argument");
throw runtime_error("bad axes argument");
for (auto sz: tmp)
for (auto& sz: tmp)
if (sz>=size_t(in.ndim()))
{
throw runtime_error("invalid axis number");
if (sz<0)
return tmp;
sz += ndim;
 
if ((sz>=ndim) || (sz<0))
 
throw invalid_argument("axes exceeds dimensionality of output");
 
}
 
return shape_t(tmp.begin(), tmp.end());
}
}
#define DISPATCH(arr, T1, T2, T3, func, args) \
#define DISPATCH(arr, T1, T2, T3, func, args) \
Loading