diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 1fbb503dfb746d3f99d47d1d2c141a54989be523..2c76ad863ba3bc8521c6f06d6d7f312001d2cbad 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -474,8 +474,8 @@ struct util // hack to avoid duplicate symbols shape_t tmp(ndim,0); for (auto ax : axes) { - if (ax>=ndim) throw runtime_error("bad axis number"); - if (++tmp[ax]>1) throw runtime_error("axis specified repeatedly"); + if (ax>=ndim) throw invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw invalid_argument("axis specified repeatedly"); } } @@ -484,7 +484,7 @@ struct util // hack to avoid duplicate symbols size_t axis) { sanity_check(shape, stride_in, stride_out, inplace); - if (axis>=shape.size()) throw runtime_error("bad axis number"); + if (axis>=shape.size()) throw invalid_argument("bad axis number"); } #ifdef POCKETFFT_OPENMP @@ -1245,7 +1245,7 @@ template void pass_all(T c[], T0 fct) POCKETFFT_NOINLINE cfftp(size_t length_) : length(length_) { - if (length==0) throw runtime_error("zero length FFT requested"); + if (length==0) throw runtime_error("zero-length FFT requested"); if (length==1) return; factorize(); mem.resize(twsize()); @@ -2085,7 +2085,7 @@ template void radbg(size_t ido, size_t ip, size_t l1, POCKETFFT_NOINLINE rfftp(size_t length_) : length(length_) { - if (length==0) throw runtime_error("zero-sized FFT"); + if (length==0) throw runtime_error("zero-length FFT requested"); if (length==1) return; factorize(); mem.resize(twsize()); @@ -3298,7 +3298,7 @@ template void r2r_dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { - if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); + if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); @@ -3319,7 +3319,7 @@ template void r2r_dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { - if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); + if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); if (util::prod(shape)==0) return; util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); cndarr ain(data_in, shape, stride_in); diff --git a/pypocketfft.cc b/pypocketfft.cc index 4137e9e1824a783cbe81b66a7192f4e73d2cff9e..fa57a574b7011427b4f1ff0b84f7c7f01e304d5f 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -250,7 +250,7 @@ template py::array r2r_dct_internal(const py::array &in, py::array r2r_dct(const py::array &in, int type, const py::object &axes_, int inorm, py::object &out_, size_t nthreads) { - if ((type<1) || (type>4)) throw runtime_error("invalid DCT type"); + if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm, out_, nthreads)) } @@ -279,7 +279,7 @@ template py::array r2r_dst_internal(const py::array &in, py::array r2r_dst(const py::array &in, int type, const py::object &axes_, int inorm, py::object &out_, size_t nthreads) { - if ((type<1) || (type>4)) throw runtime_error("invalid DST type"); + if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm, out_, nthreads)) } @@ -293,7 +293,7 @@ template py::array c2r_internal(const py::array &in, shape_t dims_in(copy_shape(in)), dims_out=dims_in; if (lastsize==0) lastsize=2*dims_in[axis]-1; if ((lastsize/2) + 1 != dims_in[axis]) - throw runtime_error("bad lastsize"); + throw invalid_argument("bad lastsize"); dims_out[axis] = lastsize; py::array res = prepare_output(out_, dims_out); auto s_in=copy_strides(in); @@ -355,7 +355,6 @@ templatepy::array complex2hartley(const py::array &in, py::gil_scoped_release release; simple_iter iin(atmp); rev_iter iout(aout, axes); - if (iin.remaining()!=iout.remaining()) throw runtime_error("oops"); while(iin.remaining()>0) { auto v = atmp[iin.ofs()];