Commit 3c7e76d6 authored by Martin Reinecke's avatar Martin Reinecke

improve error handling; throw invalid_argument instead of runtime_error in some cases

parent 5483e02e
...@@ -474,8 +474,8 @@ struct util // hack to avoid duplicate symbols ...@@ -474,8 +474,8 @@ struct util // hack to avoid duplicate symbols
shape_t tmp(ndim,0); shape_t tmp(ndim,0);
for (auto ax : axes) for (auto ax : axes)
{ {
if (ax>=ndim) throw runtime_error("bad axis number"); if (ax>=ndim) throw invalid_argument("bad axis number");
if (++tmp[ax]>1) throw runtime_error("axis specified repeatedly"); if (++tmp[ax]>1) throw invalid_argument("axis specified repeatedly");
} }
} }
...@@ -484,7 +484,7 @@ struct util // hack to avoid duplicate symbols ...@@ -484,7 +484,7 @@ struct util // hack to avoid duplicate symbols
size_t axis) size_t axis)
{ {
sanity_check(shape, stride_in, stride_out, inplace); sanity_check(shape, stride_in, stride_out, inplace);
if (axis>=shape.size()) throw runtime_error("bad axis number"); if (axis>=shape.size()) throw invalid_argument("bad axis number");
} }
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
...@@ -1245,7 +1245,7 @@ template<bool fwd, typename T> void pass_all(T c[], T0 fct) ...@@ -1245,7 +1245,7 @@ template<bool fwd, typename T> void pass_all(T c[], T0 fct)
POCKETFFT_NOINLINE cfftp(size_t length_) POCKETFFT_NOINLINE cfftp(size_t length_)
: length(length_) : length(length_)
{ {
if (length==0) throw runtime_error("zero length FFT requested"); if (length==0) throw runtime_error("zero-length FFT requested");
if (length==1) return; if (length==1) return;
factorize(); factorize();
mem.resize(twsize()); mem.resize(twsize());
...@@ -2085,7 +2085,7 @@ template<typename T> void radbg(size_t ido, size_t ip, size_t l1, ...@@ -2085,7 +2085,7 @@ template<typename T> void radbg(size_t ido, size_t ip, size_t l1,
POCKETFFT_NOINLINE rfftp(size_t length_) POCKETFFT_NOINLINE rfftp(size_t length_)
: length(length_) : length(length_)
{ {
if (length==0) throw runtime_error("zero-sized FFT"); if (length==0) throw runtime_error("zero-length FFT requested");
if (length==1) return; if (length==1) return;
factorize(); factorize();
mem.resize(twsize()); mem.resize(twsize());
...@@ -3298,7 +3298,7 @@ template<typename T> void r2r_dct(const shape_t &shape, ...@@ -3298,7 +3298,7 @@ template<typename T> void r2r_dct(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
...@@ -3319,7 +3319,7 @@ template<typename T> void r2r_dst(const shape_t &shape, ...@@ -3319,7 +3319,7 @@ template<typename T> void r2r_dst(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) int type, const T *data_in, T *data_out, T fct, size_t nthreads=1)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
if (util::prod(shape)==0) return; if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
......
...@@ -250,7 +250,7 @@ template<typename T> py::array r2r_dct_internal(const py::array &in, ...@@ -250,7 +250,7 @@ template<typename T> py::array r2r_dct_internal(const py::array &in,
py::array r2r_dct(const py::array &in, int type, const py::object &axes_, py::array r2r_dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm, DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm,
out_, nthreads)) out_, nthreads))
} }
...@@ -279,7 +279,7 @@ template<typename T> py::array r2r_dst_internal(const py::array &in, ...@@ -279,7 +279,7 @@ template<typename T> py::array r2r_dst_internal(const py::array &in,
py::array r2r_dst(const py::array &in, int type, const py::object &axes_, py::array r2r_dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads)
{ {
if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm, DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm,
out_, nthreads)) out_, nthreads))
} }
...@@ -293,7 +293,7 @@ template<typename T> py::array c2r_internal(const py::array &in, ...@@ -293,7 +293,7 @@ template<typename T> py::array c2r_internal(const py::array &in,
shape_t dims_in(copy_shape(in)), dims_out=dims_in; shape_t dims_in(copy_shape(in)), dims_out=dims_in;
if (lastsize==0) lastsize=2*dims_in[axis]-1; if (lastsize==0) lastsize=2*dims_in[axis]-1;
if ((lastsize/2) + 1 != dims_in[axis]) if ((lastsize/2) + 1 != dims_in[axis])
throw runtime_error("bad lastsize"); throw invalid_argument("bad lastsize");
dims_out[axis] = lastsize; dims_out[axis] = lastsize;
py::array res = prepare_output<T>(out_, dims_out); py::array res = prepare_output<T>(out_, dims_out);
auto s_in=copy_strides(in); auto s_in=copy_strides(in);
...@@ -355,7 +355,6 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -355,7 +355,6 @@ template<typename T>py::array complex2hartley(const py::array &in,
py::gil_scoped_release release; py::gil_scoped_release release;
simple_iter iin(atmp); simple_iter iin(atmp);
rev_iter iout(aout, axes); rev_iter iout(aout, axes);
if (iin.remaining()!=iout.remaining()) throw runtime_error("oops");
while(iin.remaining()>0) while(iin.remaining()>0)
{ {
auto v = atmp[iin.ofs()]; auto v = atmp[iin.ofs()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment