Commit 92cfcfdd authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add tests and fix a few bugs

parent 07c6a8e7
......@@ -228,7 +228,7 @@ class sincos_2pibyn
public:
sincos_2pibyn(size_t n, bool half)
: data(half ? n : 2*n)
: data(2*n)
{
sincos_2pibyn_half(n, data.data());
if (!half) fill_second_half(n, data.data());
......@@ -1653,7 +1653,7 @@ void factorize()
while ((len%divisor)==0)
{
add_factor(divisor);
length/=divisor;
len/=divisor;
}
maxl=(size_t)(sqrt((double)len))+1;
}
......@@ -2040,21 +2040,28 @@ template<typename T> void pocketfft_general_r(int ndim, const size_t *shape,
bool forward, const T *data_in, T *data_out, T fct)
{
// allocate temporary 1D array storage, if necessary
size_t tmpsize = shape[axis];
arr<T> tdata(tmpsize);
bool inplace = (data_in==data_out) && (stride_in[axis]==1);
arr<T> tdata(inplace ? 0 : shape[axis]);
multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
pocketfft_r<T> plan(shape[axis]);
multi_iter it_in(a_in, axis), it_out(a_out, axis);
while (!it_in.done())
{
for (size_t i=0; i<it_in.length(); ++i)
tdata[i] = data_in[it_in.offset()+i*it_in.stride()];
forward ? plan.forward((T *)tdata.data(), fct)
: plan.backward((T *)tdata.data(), fct);
for (size_t i=0; i<it_out.length(); ++i)
data_out[it_out.offset()+i*it_out.stride()] = tdata[i];
if (inplace)
forward ? plan.forward(data_out+it_in.offset(), fct)
: plan.backward(data_out+it_in.offset(), fct);
else
{
for (size_t i=0; i<it_in.length(); ++i)
tdata[i] = data_in[it_in.offset()+i*it_in.stride()];
forward ? plan.forward((T *)tdata.data(), fct)
: plan.backward((T *)tdata.data(), fct);
for (size_t i=0; i<it_out.length(); ++i)
data_out[it_out.offset()+i*it_out.stride()] = tdata[i];
}
it_in.advance();
it_out.advance();
if (!inplace)
it_out.advance();
}
}
......
import pypocketfft
import pyfftw
import numpy as np
import pytest
from numpy.testing import assert_allclose
pmp = pytest.mark.parametrize
shapes = ((10,), (127,), (128, 128), (128, 129), (1, 129), (129,1), (32,17,39))
@pmp("shp", shapes)
def test_fftn(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(pyfftw.interfaces.numpy_fft.fftn(a)))
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes)
def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(a))
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a))/a.size, a, atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a))/a.size, a, atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes)
def test_identity_r(shp):
a=np.random.rand(*shp)-0.5
b=a.astype(np.float32)
vmax = np.max(np.abs(a))
for ax in range(a.ndim):
n = a.shape[ax]
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,ax),ax)/n, a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(b,ax),ax)/n, b, atol=6e-7*vmax, rtol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment