diff --git a/bench.py b/bench.py index 62924447b238b4f69917f055af32d1b82012789f..095b2baa2f5650dfdf821275ad5cd2d70ed3ff79 100644 --- a/bench.py +++ b/bench.py @@ -2,7 +2,6 @@ import numpy as np import pypocketfft from time import time import matplotlib.pyplot as plt -import math def _l2error(a, b): @@ -43,9 +42,10 @@ def measure_fftw_np_interface(a, nrepeat, nthr): def measure_pypocketfft(a, nrepeat, nthr): import pypocketfft as ppf tmin = 1e38 + b = a.copy() for i in range(nrepeat): t0 = time() - b = ppf.c2c(a, forward=True, nthreads=nthr) + b = ppf.c2c(a, out=b, forward=True, nthreads=nthr) t1 = time() tmin = min(tmin, t1-t0) return tmin, b diff --git a/bench_r2c.py b/bench_r2c.py index 3ea12fb5aafa80513958d973f9310de6d5aeaa9b..37efb519b29ee6280aa04a303b97488e0589af59 100644 --- a/bench_r2c.py +++ b/bench_r2c.py @@ -2,7 +2,13 @@ import numpy as np import pypocketfft from time import time import matplotlib.pyplot as plt -import math + + +def get_complex_array(real_array, allocfunc): + tocomplex = { np.float32: np.complex64, np.float64: np.complex128 } + shape = list(real_array.shape) + shape[-1] = shape[-1]//2 + 1 + return allocfunc(shape, dtype=tocomplex[real_array.dtype.type]) def _l2error(a, b): @@ -12,11 +18,7 @@ def _l2error(a, b): def measure_fftw(a, nrepeat, nthr, flags=('FFTW_MEASURE',)): import pyfftw f1 = pyfftw.empty_aligned(a.shape, dtype=a.dtype) - tval = np.ones(1).astype(f1.dtype) - t2 = (tval+1j*tval).dtype - shape_out=list(a.shape) - shape_out[-1] = shape_out[-1]//2 + 1 - f2 = pyfftw.empty_aligned(shape_out, dtype=t2) + f2 = get_complex_array(a, pyfftw.empty_aligned) fftw = pyfftw.FFTW(f1, f2, flags=flags, axes=range(a.ndim), threads=nthr) f1[()] = a tmin = 1e38 @@ -47,9 +49,10 @@ def measure_fftw_np_interface(a, nrepeat, nthr): def measure_pypocketfft(a, nrepeat, nthr): import pypocketfft as ppf tmin = 1e38 + b = get_complex_array(a, np.empty) for i in range(nrepeat): t0 = time() - b = ppf.r2c(a, forward=True, nthreads=nthr) + b = ppf.r2c(a, forward=True, nthreads=nthr, out=b) t1 = time() tmin = min(tmin, t1-t0) return tmin, b diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index eff68499df0e8650728e5f4781371abfcda9ac3e..8549c9489118028f9c7294a36ca61e7114165957 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -3392,6 +3392,9 @@ template void r2r_genuine_hartley(const shape_t &shape, const T *data_in, T *data_out, T fct, size_t nthreads=1) { if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); shape_t tshp(shape); tshp[axes.back()] = tshp[axes.back()]/2+1;