Commit 15e37d0f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

reduce the interface. WIP.

parent da7ba7dd
...@@ -23,10 +23,10 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""): ...@@ -23,10 +23,10 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""):
tmin_pp=1e38 tmin_pp=1e38
for i in range(nrepeat): for i in range(nrepeat):
t0=time() t0=time()
b=pypocketfft.fftn(a,nthreads=nthreads) b=pypocketfft.c2c(a,nthreads=nthreads, forward=True)
t1=time() t1=time()
tmin_pp = min(tmin_pp,t1-t0) tmin_pp = min(tmin_pp,t1-t0)
a2=pypocketfft.ifftn(b,inorm=2) a2=pypocketfft.c2c(b,inorm=2, forward=False)
assert(_l2error(a,a2)<(2.5e-15 if tp=='c16' else 6e-7)) assert(_l2error(a,a2)<(2.5e-15 if tp=='c16' else 6e-7))
res.append(tmin_pp/tmin_np) res.append(tmin_pp/tmin_np)
plt.title("t(pypocketfft / numpy 1.17), {}D, {}, max_extent={}".format(ndim, str(tp), nmax)) plt.title("t(pypocketfft / numpy 1.17), {}D, {}, max_extent={}".format(ndim, str(tp), nmax))
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
namespace { namespace {
using namespace std; using namespace std;
using namespace pocketfft; using pocketfft::shape_t;
using pocketfft::stride_t;
namespace py = pybind11; namespace py = pybind11;
...@@ -104,8 +105,8 @@ template<typename T> T norm_fct(int inorm, const shape_t &shape, ...@@ -104,8 +105,8 @@ template<typename T> T norm_fct(int inorm, const shape_t &shape,
return norm_fct<T>(inorm, N); return norm_fct<T>(inorm, N);
} }
template<typename T> py::array xfftn_internal(const py::array &in, template<typename T> py::array c2c_internal(const py::array &in,
const shape_t &axes, int inorm, bool inplace, bool fwd, size_t nthreads) const shape_t &axes, bool forward, int inorm, bool inplace, size_t nthreads)
{ {
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<complex<T>>(dims); py::array res = inplace ? in : py::array_t<complex<T>>(dims);
...@@ -116,13 +117,13 @@ template<typename T> py::array xfftn_internal(const py::array &in, ...@@ -116,13 +117,13 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes); T fct = norm_fct<T>(inorm, dims, axes);
c2c(dims, s_in, s_out, axes, fwd, d_in, d_out, fct, nthreads); pocketfft::c2c(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
template<typename T> py::array sym_rfftn_internal(const py::array &in, template<typename T> py::array c2c_sym_internal(const py::array &in,
const shape_t &axes, int inorm, bool fwd, size_t nthreads) const shape_t &axes, bool forward, int inorm, size_t nthreads)
{ {
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims); py::array res = py::array_t<complex<T>>(dims);
...@@ -133,7 +134,7 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in, ...@@ -133,7 +134,7 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes); T fct = norm_fct<T>(inorm, dims, axes);
r2c(dims, s_in, s_out, axes, fwd, d_in, d_out, fct, nthreads); pocketfft::r2c(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads);
// now fill in second half // now fill in second half
using namespace pocketfft::detail; using namespace pocketfft::detail;
ndarr<complex<T>> ares(res.mutable_data(), dims, s_out); ndarr<complex<T>> ares(res.mutable_data(), dims, s_out);
...@@ -148,28 +149,20 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in, ...@@ -148,28 +149,20 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
return res; return res;
} }
py::array xfftn(const py::array &a, py::object axes, int inorm, py::array c2c(const py::array &a, py::object axes, bool forward, int inorm,
bool inplace, bool fwd, size_t nthreads) bool inplace, size_t nthreads)
{ {
if (a.dtype().kind() == 'c') if (a.dtype().kind() == 'c')
DISPATCH(a, c128, c64, clong, xfftn_internal, (a, makeaxes(a, axes), inorm, DISPATCH(a, c128, c64, clong, c2c_internal, (a, makeaxes(a, axes), forward,
inplace, fwd, nthreads)) inorm, inplace, nthreads))
if (inplace) throw runtime_error("cannot do this operation in-place"); if (inplace) throw runtime_error("cannot do this operation in-place");
DISPATCH(a, f64, f32, flong, sym_rfftn_internal, (a, makeaxes(a, axes), inorm, DISPATCH(a, f64, f32, flong, c2c_sym_internal, (a, makeaxes(a, axes), forward,
fwd, nthreads)) inorm, nthreads))
} }
py::array fftn(const py::array &a, py::object axes, int inorm, bool inplace, template<typename T> py::array r2c_internal(const py::array &in,
size_t nthreads) py::object axes_, bool forward, int inorm, size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, true, nthreads); }
py::array ifftn(const py::array &a, py::object axes, int inorm,
bool inplace, size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, false, nthreads); }
template<typename T> py::array rfftn_internal(const py::array &in,
py::object axes_, int inorm, size_t nthreads)
{ {
auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_);
auto dims_in(copy_shape(in)), dims_out(dims_in); auto dims_in(copy_shape(in)), dims_out(dims_in);
...@@ -182,20 +175,23 @@ template<typename T> py::array rfftn_internal(const py::array &in, ...@@ -182,20 +175,23 @@ template<typename T> py::array rfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_in, axes); T fct = norm_fct<T>(inorm, dims_in, axes);
r2c(dims_in, s_in, s_out, axes, true, d_in, d_out, fct, nthreads); pocketfft::r2c(dims_in, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
py::array rfftn(const py::array &in, py::object axes_, int inorm, py::array r2c(const py::array &in, py::object axes_, bool forward, int inorm,
size_t nthreads) size_t nthreads)
{ {
DISPATCH(in, f64, f32, flong, rfftn_internal, (in, axes_, inorm, nthreads)) DISPATCH(in, f64, f32, flong, r2c_internal, (in, axes_, forward, inorm,
nthreads))
} }
template<typename T> py::array xrfft_scipy(const py::array &in, template<typename T> py::array r2r_fftpack_internal(const py::array &in,
size_t axis, int inorm, bool inplace, bool fwd, size_t nthreads) py::object axes_, bool input_halfcomplex, bool forward, int inorm,
bool inplace, size_t nthreads)
{ {
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = inplace ? in : py::array_t<T>(dims); py::array res = inplace ? in : py::array_t<T>(dims);
auto s_in=copy_strides(in); auto s_in=copy_strides(in);
...@@ -204,27 +200,23 @@ template<typename T> py::array xrfft_scipy(const py::array &in, ...@@ -204,27 +200,23 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data()); auto d_out=reinterpret_cast<T *>(res.mutable_data());
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims[axis]); T fct = norm_fct<T>(inorm, dims, axes);
r2r_fftpack(dims, s_in, s_out, {axis}, fwd, fwd, d_in, d_out, fct, nthreads); pocketfft::r2r_fftpack(dims, s_in, s_out, axes, !input_halfcomplex, forward, d_in, d_out,
fct, nthreads);
} }
return res; return res;
} }
py::array rfft_scipy(const py::array &in, size_t axis, int inorm, py::array r2r_fftpack(const py::array &in, py::object axes_,
bool inplace, size_t nthreads) bool input_halfcomplex, bool forward, int inorm, bool inplace,
size_t nthreads)
{ {
DISPATCH(in, f64, f32, flong, xrfft_scipy, (in, axis, inorm, inplace, true, DISPATCH(in, f64, f32, flong, r2r_fftpack_internal, (in, axes_,
nthreads)) input_halfcomplex, forward, inorm, inplace, nthreads))
} }
py::array irfft_scipy(const py::array &in, size_t axis, int inorm, template<typename T> py::array c2r_internal(const py::array &in,
bool inplace, size_t nthreads) py::object axes_, size_t lastsize, bool forward, int inorm, size_t nthreads)
{
DISPATCH(in, f64, f32, flong, xrfft_scipy, (in, axis, inorm, inplace, false,
nthreads))
}
template<typename T> py::array irfftn_internal(const py::array &in,
py::object axes_, size_t lastsize, int inorm, size_t nthreads)
{ {
auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_);
size_t axis = axes.back(); size_t axis = axes.back();
...@@ -241,16 +233,16 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -241,16 +233,16 @@ template<typename T> py::array irfftn_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims_out, axes); T fct = norm_fct<T>(inorm, dims_out, axes);
c2r(dims_out, s_in, s_out, axes, false, d_in, d_out, fct, nthreads); pocketfft::c2r(dims_out, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, py::array c2r(const py::array &in, py::object axes_, size_t lastsize,
int inorm, size_t nthreads) bool forward, int inorm, size_t nthreads)
{ {
DISPATCH(in, c128, c64, clong, irfftn_internal, (in, axes_, lastsize, inorm, DISPATCH(in, c128, c64, clong, c2r_internal, (in, axes_, lastsize, forward,
nthreads)) inorm, nthreads))
} }
template<typename T> py::array hartley_internal(const py::array &in, template<typename T> py::array hartley_internal(const py::array &in,
...@@ -266,7 +258,7 @@ template<typename T> py::array hartley_internal(const py::array &in, ...@@ -266,7 +258,7 @@ template<typename T> py::array hartley_internal(const py::array &in,
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
T fct = norm_fct<T>(inorm, dims, axes); T fct = norm_fct<T>(inorm, dims, axes);
r2r_hartley(dims, s_in, s_out, axes, d_in, d_out, fct, nthreads); pocketfft::r2r_hartley(dims, s_in, s_out, axes, d_in, d_out, fct, nthreads);
} }
return res; return res;
} }
...@@ -312,7 +304,7 @@ py::array mycomplex2hartley(const py::array &in, ...@@ -312,7 +304,7 @@ py::array mycomplex2hartley(const py::array &in,
py::array hartley2(const py::array &in, py::object axes_, int inorm, py::array hartley2(const py::array &in, py::object axes_, int inorm,
bool inplace, size_t nthreads) bool inplace, size_t nthreads)
{ {
return mycomplex2hartley(in, rfftn(in, axes_, inorm, nthreads), axes_, return mycomplex2hartley(in, r2c(in, axes_, true, inorm, nthreads), axes_,
inplace); inplace);
} }
...@@ -328,8 +320,7 @@ vector instructions for faster execution if these are supported by the CPU and ...@@ -328,8 +320,7 @@ vector instructions for faster execution if these are supported by the CPU and
were enabled during compilation. were enabled during compilation.
)"""; )""";
const char *fftn_DS = R"""( const char *c2c_DS = R"""(Performs a complex FFT.
Performs a forward complex FFT.
Parameters Parameters
---------- ----------
...@@ -339,6 +330,8 @@ a : numpy.ndarray (any complex or real type) ...@@ -339,6 +330,8 @@ a : numpy.ndarray (any complex or real type)
axes : list of integers axes : list of integers
The axes along which the FFT is carried out. The axes along which the FFT is carried out.
If not set, all axes will be transformed. If not set, all axes will be transformed.
forward : bool
If `True`, a negative sign is used in the exponent, else a positive one.
inorm : int inorm : int
Normalization type Normalization type
0 : no normalization 0 : no normalization
...@@ -360,38 +353,7 @@ np.ndarray (same shape as `a`, complex type with same accuracy as `a`) ...@@ -360,38 +353,7 @@ np.ndarray (same shape as `a`, complex type with same accuracy as `a`)
The transformed data. The transformed data.
)"""; )""";
const char *ifftn_DS = R"""(Performs a backward complex FFT. const char *r2c_DS = R"""(Performs an FFT whose input is strictly real.
Parameters
----------
a : numpy.ndarray (any complex or real type)
The input data. If its type is real, a more efficient real-to-complex
transform will be used.
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed.
inorm : int
Normalization type
0 : no normalization
1 : divide by sqrt(N)
2 : divide by N
where N is the product of the lengths of the transformed axes.
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
NOTE: if `a` is real-valued and `inplace` is `True`, an exception will be
raised!
nthreads : int
Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable).
Returns
-------
np.ndarray (same shape as `a`, complex type with same accuracy as `a`)
The transformed data
)""";
const char *rfftn_DS = R"""(Performs a forward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -400,6 +362,8 @@ a : numpy.ndarray (any real type) ...@@ -400,6 +362,8 @@ a : numpy.ndarray (any real type)
axes : list of integers axes : list of integers
The axes along which the FFT is carried out. The axes along which the FFT is carried out.
If not set, all axes will be transformed in ascending order. If not set, all axes will be transformed in ascending order.
forward : bool
If `True`, a negative sign is used in the exponent, else a positive one.
inorm : int inorm : int
Normalization type Normalization type
0 : no normalization 0 : no normalization
...@@ -418,36 +382,7 @@ np.ndarray (complex type with same accuracy as `a`) ...@@ -418,36 +382,7 @@ np.ndarray (complex type with same accuracy as `a`)
was n on input, it is n//2+1 on output. was n on input, it is n//2+1 on output.
)"""; )""";
const char *rfft_scipy_DS = R"""(Performs a forward real-valued FFT. const char *c2r_DS = R"""(Performs an FFT whose output is strictly real.
Parameters
----------
a : numpy.ndarray (any real type)
The input data
axis : int
The axis along which the FFT is carried out.
inorm : int
Normalization type
0 : no normalization
1 : divide by sqrt(N)
2 : divide by N
where N is the length of `axis`.
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
nthreads : int
Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable).
Returns
-------
np.ndarray (same shape and data type as `a`)
The transformed data. The shape is identical to that of the input array.
Along the transformed axis, values are arranged in
FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`.
)""";
const char *irfftn_DS = R"""(Performs a backward real-valued FFT.
Parameters Parameters
---------- ----------
...@@ -458,6 +393,8 @@ axes : list of integers ...@@ -458,6 +393,8 @@ axes : list of integers
If not set, all axes will be transformed in ascending order. If not set, all axes will be transformed in ascending order.
lastsize : the output size of the last axis to be transformed. lastsize : the output size of the last axis to be transformed.
If the corresponding input axis has size n, this can be 2*n-2 or 2*n-1. If the corresponding input axis has size n, this can be 2*n-2 or 2*n-1.
forward : bool
If `True`, a negative sign is used in the exponent, else a positive one.
inorm : int inorm : int
Normalization type Normalization type
0 : no normalization 0 : no normalization
...@@ -476,15 +413,20 @@ np.ndarray (real type with same accuracy as `a`) ...@@ -476,15 +413,20 @@ np.ndarray (real type with same accuracy as `a`)
entries. entries.
)"""; )""";
const char *irfft_scipy_DS = R"""(Performs a backward real-valued FFT. const char *r2r_fftpack_DS = R"""(Performs a real-valued FFT using the FFTPACK storage scheme.
Parameters Parameters
---------- ----------
a : numpy.ndarray (any real type) a : numpy.ndarray (any real type)
The input data. Along the transformed axis, values are expected in The input data
FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`. axes : list of integers
axis : int The axes along which the FFT is carried out.
The axis along which the FFT is carried out. If not set, all axes will be transformed.
input_halfcomplex : bool
if True, the transform will go from halfcomplex to real format, else in
the other direction
forward : bool
If `True`, a negative sign is used in the exponent, else a positive one.
inorm : int inorm : int
Normalization type Normalization type
0 : no normalization 0 : no normalization
...@@ -542,18 +484,15 @@ PYBIND11_MODULE(pypocketfft, m) ...@@ -542,18 +484,15 @@ PYBIND11_MODULE(pypocketfft, m)
using namespace pybind11::literals; using namespace pybind11::literals;
m.doc() = pypocketfft_DS; m.doc() = pypocketfft_DS;
m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=py::none(), "inorm"_a=0, m.def("c2c",&c2c, c2c_DS, "a"_a, "axes"_a=py::none(), "forward"_a=true,
"inplace"_a=false, "nthreads"_a=1);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "inorm"_a=0,
"inplace"_a=false, "nthreads"_a=1);
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "inorm"_a=0,
"nthreads"_a=1);
m.def("rfft_scipy",&rfft_scipy, rfft_scipy_DS, "a"_a, "axis"_a, "inorm"_a=0,
"inplace"_a=false, "nthreads"_a=1);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0,
"inorm"_a=0, "nthreads"_a=1);
m.def("irfft_scipy",&irfft_scipy, irfft_scipy_DS, "a"_a, "axis"_a,
"inorm"_a=0, "inplace"_a=false, "nthreads"_a=1); "inorm"_a=0, "inplace"_a=false, "nthreads"_a=1);
m.def("r2c",&r2c, r2c_DS, "a"_a, "axes"_a=py::none(), "forward"_a=true,
"inorm"_a=0, "nthreads"_a=1);
m.def("c2r",&c2r, c2r_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0,
"forward"_a=true, "inorm"_a=0, "nthreads"_a=1);
m.def("r2r_fftpack",&r2r_fftpack, r2r_fftpack_DS, "a"_a, "axes"_a,
"input_halfcomplex"_a, "forward"_a, "inorm"_a=0, "inplace"_a=false,
"nthreads"_a=1);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "inorm"_a=0, m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "inorm"_a=0,
"inplace"_a=false, "nthreads"_a=1); "inplace"_a=false, "nthreads"_a=1);
m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "inorm"_a=0, m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "inorm"_a=0,
......
...@@ -4,6 +4,31 @@ import pypocketfft ...@@ -4,6 +4,31 @@ import pypocketfft
def _l2error(a,b): def _l2error(a,b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
def fftn(a, axes=None, inorm=0, inplace=False, nthreads=1):
return pypocketfft.c2c(a, axes=axes, forward=True, inorm=inorm,
inplace=inplace, nthreads=nthreads)
def ifftn(a, axes=None, inorm=0, inplace=False, nthreads=1):
return pypocketfft.c2c(a, axes=axes, forward=False, inorm=inorm,
inplace=inplace, nthreads=nthreads)
def rfftn(a, axes=None, inorm=0, nthreads=1):
return pypocketfft.r2c(a, axes=axes, forward=True, inorm=inorm,
nthreads=nthreads)
def irfftn(a, axes=None, lastsize=0, inorm=0, nthreads=1):
return pypocketfft.c2r(a, axes=axes, lastsize=lastsize,forward=False,
inorm=inorm, nthreads=nthreads)
def rfft_scipy(a, axis, inorm=0, inplace=False, nthreads=1):
return pypocketfft.r2r_fftpack(a, axes=(axis,), input_halfcomplex=False,
forward=True, inorm=inorm, inplace=inplace,
nthreads=nthreads)
def irfft_scipy(a, axis, inorm=0, inplace=False, nthreads=1):
return pypocketfft.r2r_fftpack(a, axes=(axis,), input_halfcomplex=True,
forward=False, inorm=inorm, inplace=inplace,
nthreads=nthreads)
nthreads=0 nthreads=0
cmaxerr=0. cmaxerr=0.
fmaxerr=0. fmaxerr=0.
...@@ -22,32 +47,32 @@ def test(): ...@@ -22,32 +47,32 @@ def test():
axes = axes[:nax] axes = axes[:nax]
lastsize = shape[axes[-1]] lastsize = shape[axes[-1]]
a=np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j a=np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j
b=pypocketfft.ifftn(pypocketfft.fftn(a,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads) b=ifftn(fftn(a,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a,b) err = _l2error(a,b)
if err > cmaxerr: if err > cmaxerr:
cmaxerr = err cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.ifftn(pypocketfft.fftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads) b=ifftn(fftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > cmaxerr: if err > cmaxerr:
cmaxerr = err cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.fftn(pypocketfft.ifftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads) b=fftn(ifftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > cmaxerr: if err > cmaxerr:
cmaxerr = err cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.irfftn(pypocketfft.rfftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,lastsize=lastsize,nthreads=nthreads) b=irfftn(rfftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,lastsize=lastsize,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > fmaxerr: if err > fmaxerr:
fmaxerr = err fmaxerr = err
print("fmaxerr:", fmaxerr, shape, axes) print("fmaxerr:", fmaxerr, shape, axes)
b=pypocketfft.ifftn(pypocketfft.fftn(a.astype(np.complex64),axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads) b=ifftn(fftn(a.astype(np.complex64),axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.astype(np.complex64),b) err = _l2error(a.astype(np.complex64),b)
if err > cmaxerrf: if err > cmaxerrf:
cmaxerrf = err cmaxerrf = err
print("cmaxerrf:", cmaxerrf, shape, axes) print("cmaxerrf:", cmaxerrf, shape, axes)
b=pypocketfft.irfftn(pypocketfft.rfftn(a.real.astype(np.float32),axes=axes,nthreads=nthreads),axes=axes,inorm=2,lastsize=lastsize,nthreads=nthreads) b=irfftn(rfftn(a.real.astype(np.float32),axes=axes,nthreads=nthreads),axes=axes,inorm=2,lastsize=lastsize,nthreads=nthreads)
err = _l2error(a.real.astype(np.float32),b) err = _l2error(a.real.astype(np.float32),b)
if err > fmaxerrf: if err > fmaxerrf:
fmaxerrf = err fmaxerrf = err
......
...@@ -15,77 +15,101 @@ len1D=range(1,2048) ...@@ -15,77 +15,101 @@ len1D=range(1,2048)
def _l2error(a,b): def _l2error(a,b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
def fftn(a, axes=None, inorm=0, inplace=False, nthreads=1):
return pypocketfft.c2c(a, axes=axes, forward=True, inorm=inorm,
inplace=inplace, nthreads=nthreads)
def ifftn(a, axes=None, inorm=0, inplace=False, nthreads=1):
return pypocketfft.c2c(a, axes=axes, forward=False, inorm=inorm,
inplace=inplace, nthreads=nthreads)
def rfftn(a, axes=None, inorm=0, nthreads=1):
return pypocketfft.r2c(a, axes=axes, forward=True, inorm=inorm,
nthreads=nthreads)
def irfftn(a, axes=None, lastsize=0, inorm=0, nthreads=1):
return pypocketfft.c2r(a, axes=axes, lastsize=lastsize,forward=False,
inorm=inorm, nthreads=nthreads)
def rfft_scipy(a, axis, inorm=0, inplace=False, nthreads=1):
return pypocketfft.r2r_fftpack(a, axes=(axis,), input_halfcomplex=False,