Commit 0ad63b5f authored by Martin Reinecke's avatar Martin Reinecke

sync

parent 89a7dd70
......@@ -19,6 +19,7 @@
#include <stdexcept>
#include <memory>
#include <vector>
#include <complex>
#if defined(_WIN32)
#include <malloc.h>
#endif
......@@ -2381,9 +2382,9 @@ template<typename T> NOINLINE void general_r(
} // namespace detail
template<typename T> void c2c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const void *data_in, void *data_out, T fct)
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward,
const std::complex<T> *data_in, std::complex<T> *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
......@@ -2392,23 +2393,23 @@ template<typename T> void c2c(const shape_t &shape,
general_c(ain, aout, axes, forward, fct);
}
template<typename T> void r2c(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const void *data_in, void *data_out, T fct)
template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, size_t axis, const T *data_in,
std::complex<T> *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
util::sanity_check(shape, stride_in, stride_out, false, axis);
ndarr<T> ain(data_in, shape, stride_in);
ndarr<cmplx<T>> aout(data_out, shape, stride_out);
general_r2c(ain, aout, axis, fct);
}
template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, const void *data_in,
void *data_out, T fct)
const stride_t &stride_out, const shape_t &axes, const T *data_in,
std::complex<T> *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
util::sanity_check(shape, stride_in, stride_out, false, axes);
r2c(shape, stride_in, stride_out, axes.back(), data_in, data_out, fct);
if (axes.size()==1) return;
......@@ -2421,10 +2422,10 @@ template<typename T> void r2c(const shape_t &shape, const stride_t &stride_in,
template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
const void *data_in, void *data_out, T fct)
const std::complex<T> *data_in, T *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
util::sanity_check(shape, stride_in, stride_out, false, axis);
shape_t shape_out(shape);
shape_out[axis] = new_size;
ndarr<cmplx<T>> ain(data_in, shape, stride_in);
......@@ -2434,7 +2435,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
template<typename T> void c2r(const shape_t &shape, size_t new_size,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const void *data_in, void *data_out, T fct)
const std::complex<T> *data_in, T *data_out, T fct)
{
using namespace detail;
if (axes.size()==1)
......@@ -2443,13 +2444,13 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
data_in, data_out, fct);
return;
}
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
util::sanity_check(shape, stride_in, stride_out, false, axes);
auto nval = util::prod(shape);
stride_t stride_inter(shape.size());
stride_inter.back() = sizeof(cmplx<T>);
for (int i=shape.size()-2; i>=0; --i)
stride_inter[i] = stride_inter[i+1]*shape[i+1];
arr<char> tmp(nval*sizeof(cmplx<T>));
arr<complex<T>> tmp(nval);
auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape, stride_in, stride_inter, newaxes, false, data_in, tmp.data(), T(1));
c2r(shape, new_size, stride_inter, stride_out, axes.back(),
......@@ -2458,7 +2459,7 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
template<typename T> void r2r_fftpack(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const void *data_in, void *data_out, T fct)
bool forward, const T *data_in, T *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axis);
......@@ -2468,7 +2469,7 @@ template<typename T> void r2r_fftpack(const shape_t &shape,
template<typename T> void r2r_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const void *data_in, void *data_out, T fct)
const T *data_in, T *data_out, T fct)
{
using namespace detail;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
......
......@@ -82,8 +82,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<complex<T>>(dims);
c2c(dims, copy_strides(in), copy_strides(res), axes, fwd, in.data(),
res.mutable_data(), T(fct));
c2c(dims, copy_strides(in), copy_strides(res), axes, fwd,
reinterpret_cast<const complex<T> *>(in.data()),
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res;
}
......@@ -106,8 +107,9 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto dims_in(copy_shape(in)), dims_out(dims_in);
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
py::array res = py::array_t<complex<T>>(dims_out);
r2c(dims_in, copy_strides(in), copy_strides(res), axes, in.data(),
res.mutable_data(), T(fct));
r2c(dims_in, copy_strides(in), copy_strides(res), axes,
reinterpret_cast<const T *>(in.data()),
reinterpret_cast<complex<T> *>(res.mutable_data()), T(fct));
return res;
}
py::array rfftn(const py::array &in, py::object axes_, double fct)
......@@ -121,7 +123,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims);
r2r_fftpack(dims, copy_strides(in), copy_strides(res), axis, fwd,
in.data(), res.mutable_data(), T(fct));
reinterpret_cast<const T *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
......@@ -149,7 +152,8 @@ template<typename T> py::array irfftn_internal(const py::array &in,
dims_out[axis] = lastsize;
py::array res = py::array_t<T>(dims_out);
c2r(dims_in, lastsize, copy_strides(in), copy_strides(res), axes,
in.data(), res.mutable_data(), T(fct));
reinterpret_cast<const complex<T> *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
......@@ -166,7 +170,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims);
r2r_hartley(dims, copy_strides(in), copy_strides(res), makeaxes(in, axes_),
in.data(), res.mutable_data(), T(fct));
reinterpret_cast<const T *>(in.data()),
reinterpret_cast<T *>(res.mutable_data()), T(fct));
return res;
}
py::array hartley(const py::array &in, py::object axes_, double fct,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment