Skip to content
Snippets Groups Projects
Commit ac09e78b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

enable long double; warning fixes

parent e02245ca
No related branches found
No related tags found
1 merge request!6Develop
...@@ -29,24 +29,24 @@ namespace py = pybind11; ...@@ -29,24 +29,24 @@ namespace py = pybind11;
auto c64 = py::dtype("complex64"); auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128"); auto c128 = py::dtype("complex128");
//auto c256 = py::dtype("complex256"); auto c256 = py::dtype("complex256");
auto f32 = py::dtype("float32"); auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64"); auto f64 = py::dtype("float64");
//auto f128 = py::dtype("float128"); auto f128 = py::dtype("float128");
shape_t copy_shape(const py::array &arr) shape_t copy_shape(const py::array &arr)
{ {
shape_t res(arr.ndim()); shape_t res(size_t(arr.ndim()));
for (size_t i=0; i<res.size(); ++i) for (size_t i=0; i<res.size(); ++i)
res[i] = arr.shape(i); res[i] = size_t(arr.shape(int(i)));
return res; return res;
} }
stride_t copy_strides(const py::array &arr) stride_t copy_strides(const py::array &arr)
{ {
stride_t res(arr.ndim()); stride_t res(size_t(arr.ndim()));
for (size_t i=0; i<res.size(); ++i) for (size_t i=0; i<res.size(); ++i)
res[i] = arr.strides(i); res[i] = arr.strides(int(i));
return res; return res;
} }
...@@ -54,7 +54,7 @@ shape_t makeaxes(const py::array &in, py::object axes) ...@@ -54,7 +54,7 @@ shape_t makeaxes(const py::array &in, py::object axes)
{ {
if (axes.is(py::none())) if (axes.is(py::none()))
{ {
shape_t res(in.ndim()); shape_t res(size_t(in.ndim()));
for (size_t i=0; i<res.size(); ++i) for (size_t i=0; i<res.size(); ++i)
res[i]=i; res[i]=i;
return res; return res;
...@@ -72,7 +72,7 @@ shape_t makeaxes(const py::array &in, py::object axes) ...@@ -72,7 +72,7 @@ shape_t makeaxes(const py::array &in, py::object axes)
auto dtype = arr.dtype(); \ auto dtype = arr.dtype(); \
if (dtype.is(T1)) return func<double> args; \ if (dtype.is(T1)) return func<double> args; \
if (dtype.is(T2)) return func<float> args; \ if (dtype.is(T2)) return func<float> args; \
/* if (dtype.is(T3)) return func<long double> args; */ \ if (dtype.is(T3)) return func<long double> args; \
throw runtime_error("unsupported data type"); throw runtime_error("unsupported data type");
template<typename T> py::array xfftn_internal(const py::array &in, template<typename T> py::array xfftn_internal(const py::array &in,
...@@ -190,7 +190,7 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -190,7 +190,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace) const py::array &tmp, py::object axes_, bool inplace)
{ {
using namespace pocketfft::detail; using namespace pocketfft::detail;
int ndim = in.ndim(); size_t ndim = size_t(in.ndim());
auto dims_out(copy_shape(in)); auto dims_out(copy_shape(in));
py::array out = inplace ? in : py::array_t<T>(dims_out); py::array out = inplace ? in : py::array_t<T>(dims_out);
ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp)); ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
...@@ -209,10 +209,10 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -209,10 +209,10 @@ template<typename T>py::array complex2hartley(const py::array &in,
{ {
if (i==axis) continue; if (i==axis) continue;
if (!swp[i]) if (!swp[i])
rofs += it.pos[i]*it.oarr.stride(i); rofs += ptrdiff_t(it.pos[i])*it.oarr.stride(i);
else else
{ {
auto x = (it.pos[i]==0) ? 0 : it.iarr.shape(i)-it.pos[i]; auto x = ptrdiff_t((it.pos[i]==0) ? 0 : it.iarr.shape(i)-it.pos[i]);
rofs += x*it.oarr.stride(i); rofs += x*it.oarr.stride(i);
} }
} }
...@@ -221,7 +221,7 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -221,7 +221,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
{ {
auto re = it.in(i).r; auto re = it.in(i).r;
auto im = it.in(i).i; auto im = it.in(i).i;
auto rev_i = (i==0) ? 0 : it.length_out()-i; auto rev_i = ptrdiff_t((i==0) ? 0 : it.length_out()-i);
it.out(i) = re+im; it.out(i) = re+im;
aout[rofs + rev_i*it.stride_out()] = re-im; aout[rofs + rev_i*it.stride_out()] = re-im;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment