Commit 0ad63b5f authored by Martin Reinecke's avatar Martin Reinecke

sync

parent 89a7dd70
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <stdexcept> #include <stdexcept>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <complex>
#if defined(_WIN32) #if defined(_WIN32)
#include <malloc.h> #include <malloc.h>
#endif #endif
...@@ -2381,9 +2382,9 @@ template<typename T> NOINLINE void general_r( ...@@ -2381,9 +2382,9 @@ template<typename T> NOINLINE void general_r(
} // namespace detail } // namespace detail
template<typename T> void c2c(const shape_t &shape, template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_out, const shape_t &axes, bool forward,
bool forward, const void *data_in, void *data_out, T fct) const std::complex<T> *data_in, std::complex<T> *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
...@@ -2392,23 +2393,23 @@ template<typename T> void c2c(const shape_t &shape, ...@@ -2392,23 +2393,23 @@ template<typename T> void c2c(const shape_t &shape,
general_c(ain, aout, axes, forward, fct); general_c(ain, aout, axes, forward, fct);
} }
template<typename T> void r2c(const shape_t &shape, template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_in, const stride_t &stride_out, size_t axis, const stride_t &stride_out, size_t axis, const T *data_in,
const void *data_in, void *data_out, T fct) std::complex<T> *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); util::sanity_check(shape, stride_in, stride_out, false, axis);
ndarr<T> ain(data_in, shape, stride_in); ndarr<T> ain(data_in, shape, stride_in);
ndarr<cmplx<T>> aout(data_out, shape, stride_out); ndarr<cmplx<T>> aout(data_out, shape, stride_out);
general_r2c(ain, aout, axis, fct); general_r2c(ain, aout, axis, fct);
} }
template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in, template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, const void *data_in, const stride_t &stride_out, const shape_t &axes, const T *data_in,
void *data_out, T fct) std::complex<T> *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, false, axes);
r2c(shape, stride_in, stride_out, axes.back(), data_in, data_out, fct); r2c(shape, stride_in, stride_out, axes.back(), data_in, data_out, fct);
if (axes.size()==1) return; if (axes.size()==1) return;
...@@ -2421,10 +2422,10 @@ template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in, ...@@ -2421,10 +2422,10 @@ template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
template<typename T> void c2r(const shape_t &shape, size_t new_size, template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, size_t axis, const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const void *data_in, void *data_out, T fct) const std::complex<T> *data_in, T *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); util::sanity_check(shape, stride_in, stride_out, false, axis);
shape_t shape_out(shape); shape_t shape_out(shape);
shape_out[axis] = new_size; shape_out[axis] = new_size;
ndarr<cmplx<T>> ain(data_in, shape, stride_in); ndarr<cmplx<T>> ain(data_in, shape, stride_in);
...@@ -2434,7 +2435,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size, ...@@ -2434,7 +2435,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
template<typename T> void c2r(const shape_t &shape, size_t new_size, template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const void *data_in, void *data_out, T fct) const std::complex<T> *data_in, T *data_out, T fct)
{ {
using namespace detail; using namespace detail;
if (axes.size()==1) if (axes.size()==1)
...@@ -2443,13 +2444,13 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size, ...@@ -2443,13 +2444,13 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
data_in, data_out, fct); data_in, data_out, fct);
return; return;
} }
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, false, axes);
auto nval = util::prod(shape); auto nval = util::prod(shape);
stride_t stride_inter(shape.size()); stride_t stride_inter(shape.size());
stride_inter.back() = sizeof(cmplx<T>); stride_inter.back() = sizeof(cmplx<T>);
for (int i=shape.size()-2; i>=0; --i) for (int i=shape.size()-2; i>=0; --i)
stride_inter[i] = stride_inter[i+1]*shape[i+1]; stride_inter[i] = stride_inter[i+1]*shape[i+1];
arr<char> tmp(nval*sizeof(cmplx<T>)); arr<complex<T>> tmp(nval);
auto newaxes = shape_t({axes.begin(), --axes.end()}); auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), T(1)); c2c(shape, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), T(1));
c2r(shape, new_size, stride_inter, stride_out, axes.back(), c2r(shape, new_size, stride_inter, stride_out, axes.back(),
...@@ -2458,7 +2459,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size, ...@@ -2458,7 +2459,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
template<typename T> void r2r_fftpack(const shape_t &shape, template<typename T> void r2r_fftpack(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, size_t axis, const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const void *data_in, void *data_out, T fct) bool forward, const T *data_in, T *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
...@@ -2468,7 +2469,7 @@ template<typename T> void r2r_fftpack(const shape_t &shape, ...@@ -2468,7 +2469,7 @@ template<typename T> void r2r_fftpack(const shape_t &shape,
template<typename T> void r2r_hartley(const shape_t &shape, template<typename T> void r2r_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const void *data_in, void *data_out, T fct) const T *data_in, T *data_out, T fct)
{ {
using namespace detail; using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
......
...@@ -82,8 +82,9 @@ template<typename T> py::array xfftn_internal(const py::array &in, ...@@ -82,8 +82,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{ {
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<complex<T>>(dims); py::array res = inplace ? in : py::array_t<complex<T>>(dims);
c2c(dims, copy_strides(in), copy_strides(res), axes, fwd, in.data(), c2c(dims, copy_strides(in), copy_strides(res), axes, fwd,
res.mutable_data(), T(fct)); reinterpret_cast<const complex<T> *>(in.data()),
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res; return res;
} }
...@@ -106,8 +107,9 @@ template<typename T> py::array rfftn_internal(const py::array &in, ...@@ -106,8 +107,9 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto dims_in(copy_shape(in)), dims_out(dims_in); auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1; dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
py::array res = py::array_t<complex<T>>(dims_out); py::array res = py::array_t<complex<T>>(dims_out);
r2c(dims_in, copy_strides(in), copy_strides(res), axes, in.data(), r2c(dims_in, copy_strides(in), copy_strides(res), axes,
res.mutable_data(), T(fct)); reinterpret_cast<const T *>(in.data()),
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array rfftn(const py::array &in, py::object axes_, double fct) py::array rfftn(const py::array &in, py::object axes_, double fct)
...@@ -121,7 +123,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in, ...@@ -121,7 +123,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims); py::array res = inplace ? in : py::array_t<T>(dims);
r2r_fftpack(dims, copy_strides(in), copy_strides(res), axis, fwd, r2r_fftpack(dims, copy_strides(in), copy_strides(res), axis, fwd,
in.data(), res.mutable_data(), T(fct)); reinterpret_cast<const T *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace) py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
...@@ -149,7 +152,8 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -149,7 +152,8 @@ template<typename T> py::array irfftn_internal(const py::array &in,
dims_out[axis] = lastsize; dims_out[axis] = lastsize;
py::array res = py::array_t<T>(dims_out); py::array res = py::array_t<T>(dims_out);
c2r(dims_in, lastsize, copy_strides(in), copy_strides(res), axes, c2r(dims_in, lastsize, copy_strides(in), copy_strides(res), axes,
in.data(), res.mutable_data(), T(fct)); reinterpret_cast<const complex<T> *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
...@@ -166,7 +170,8 @@ template<typename T> py::array hartley_internal(const py::array &in, ...@@ -166,7 +170,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims); py::array res = inplace ? in : py::array_t<T>(dims);
r2r_hartley(dims, copy_strides(in), copy_strides(res), makeaxes(in, axes_), r2r_hartley(dims, copy_strides(in), copy_strides(res), makeaxes(in, axes_),
in.data(), res.mutable_data(), T(fct)); reinterpret_cast<const T *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res; return res;
} }
py::array hartley(const py::array &in, py::object axes_, double fct, py::array hartley(const py::array &in, py::object axes_, double fct,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment