Commit 572fc723 authored by Martin Reinecke

better interface, some docstrings

parent d68a9f1e
......@@ -63,7 +63,8 @@ class sincos_2pibyn
arr<double> data;
// adapted from https://stackoverflow.com/questions/42792939/
// CAUTION: this function only works for arguments in the range [-0.25; 0.25]!
// CAUTION: this function only works for arguments in the range
// [-0.25; 0.25]!
void my_sincosm1pi (double a, double *restrict res)
{
double s = a * a;
......@@ -2080,35 +2081,6 @@ template<typename T> void pocketfft_general_hartley(int ndim, const size_t *shap
}
}
// Runs a 1D real transform along `axis` over every line of an ndim-dimensional
// array, visiting the lines through multi_iter.
//   ndim, shape          : number of dimensions and extent of each dimension
//   stride_in/stride_out : per-dimension strides, used as element offsets into
//                          data_in / data_out
//   axis                 : dimension along which the transform is applied
//   forward              : true -> plan.forward(), false -> plan.backward()
//   data_in, data_out    : source and destination buffers (may be identical)
//   fct                  : passed unchanged to plan.forward()/plan.backward()
template<typename T> void pocketfft_general_r(int ndim, const size_t *shape,
  const int64_t *stride_in, const int64_t *stride_out, size_t axis,
  bool forward, const T *data_in, T *data_out, T fct)
{
  // The plan can work directly on the output buffer only when input and
  // output coincide and the transformed axis has unit input stride.
  const bool inplace = (data_in==data_out) && (stride_in[axis]==1);
  // Scratch line; left empty when the transform runs in place.
  arr<T> tdata(inplace ? 0 : shape[axis]);
  multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
  pocketfft_r<T> plan(shape[axis]);
  multi_iter it_in(a_in, axis), it_out(a_out, axis);

  while (!it_in.done())
  {
    if (inplace)
    {
      // Transform the current line where it lives; only it_in is advanced.
      T *line = data_out+it_in.offset();
      if (forward)
        plan.forward(line, fct);
      else
        plan.backward(line, fct);
      it_in.advance();
      continue;
    }

    // Gather the strided input line into the scratch buffer ...
    for (size_t k=0; k<it_in.length(); ++k)
      tdata[k] = data_in[it_in.offset()+k*it_in.stride()];

    // ... transform it there ...
    if (forward)
      plan.forward((T *)tdata.data(), fct);
    else
      plan.backward((T *)tdata.data(), fct);

    // ... and scatter the result to the strided output line.
    for (size_t k=0; k<it_out.length(); ++k)
      data_out[it_out.offset()+k*it_out.stride()] = tdata[k];

    it_in.advance();
    it_out.advance();
  }
}
template<typename T> void pocketfft_general_r2c(int ndim, const size_t *shape,
const int64_t *stride_in, const int64_t *stride_out, size_t axis,
const T *data_in, cmplx<T> *data_out, T fct)
......@@ -2203,7 +2175,7 @@ void check_args(const py::array &in, vector<size_t> &axes)
throw runtime_error("invalid axis number");
}
py::array execute(const py::array &in, vector<size_t> axes, const string &norm, bool fwd)
py::array execute(const py::array &in, vector<size_t> axes, double fct, bool fwd)
{
check_args(in, axes);
py::array res;
......@@ -2211,15 +2183,6 @@ py::array execute(const py::array &in, vector<size_t> axes, const string &norm,
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
size_t n_elem=1;
for (size_t i=0; i<axes.size(); ++i)
n_elem *= in.shape(axes[i]);
double fct=1;
if (norm=="ortho")
fct=1./sqrt(n_elem);
else
if (!fwd)
fct = 1./n_elem;
if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(c64))
......@@ -2236,22 +2199,22 @@ py::array execute(const py::array &in, vector<size_t> axes, const string &norm,
}
if (in.dtype().is(c128))
pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<double> *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<double> *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
else
pocketfft_general_c<float>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<float> *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<float> *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
return res;
}
// Python-facing forward transform taking a `norm` string: hands the arguments
// to execute() with the direction flag set to true.
py::array fftn(const py::array &a, const vector<size_t> &axes, const string &norm)
{
  const bool fwd = true;
  return execute(a, axes, norm, fwd);
}
// Python-facing backward transform taking a `norm` string: hands the arguments
// to execute() with the direction flag set to false.
py::array ifftn(const py::array &a, const vector<size_t> &axes, const string &norm)
{
  const bool fwd = false;
  return execute(a, axes, norm, fwd);
}
// Python-facing forward transform taking a numeric factor `fct`: hands the
// arguments to execute() with the direction flag set to true.
py::array fftn(const py::array &a, const vector<size_t> &axes, double fct)
{
  const bool fwd = true;
  return execute(a, axes, fct, fwd);
}
// Python-facing backward transform taking a numeric factor `fct`: hands the
// arguments to execute() with the direction flag set to false.
py::array ifftn(const py::array &a, const vector<size_t> &axes, double fct)
{
  const bool fwd = false;
  return execute(a, axes, fct, fwd);
}
py::array rfftn(const py::array &in, vector<size_t> axes, const string &norm)
py::array rfftn(const py::array &in, vector<size_t> axes, double fct)
{
check_args(in, axes);
py::array res;
......@@ -2263,12 +2226,6 @@ py::array rfftn(const py::array &in, vector<size_t> axes, const string &norm)
dims_out[i] = in.shape(i);
}
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
size_t n_elem=1;
for (size_t i=0; i<axes.size(); ++i)
n_elem *= in.shape(axes[i]);
double fct=1;
if (norm=="ortho")
fct=1./sqrt(n_elem);
if (in.dtype().is(f64))
res = py::array_t<complex<double>>(dims_out);
else if (in.dtype().is(f32))
......@@ -2286,28 +2243,31 @@ py::array rfftn(const py::array &in, vector<size_t> axes, const string &norm)
if (in.dtype().is(f64))
{
pocketfft_general_r2c<double>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axes.back(), (const double *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
pocketfft_general_c<double>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<double> *)res.data(),
(cmplx<double> *)res.mutable_data(), 1.);
s_i.data(), s_o.data(),
axes.back(), (const double *)in.data(),
(cmplx<double> *)res.mutable_data(), fct);
if (axes.size()>1)
pocketfft_general_c<double>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<double> *)res.data(),
(cmplx<double> *)res.mutable_data(), 1.);
}
else
{
pocketfft_general_r2c<float>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axes.back(), (const float *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
pocketfft_general_c<float>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<float> *)res.data(),
(cmplx<float> *)res.mutable_data(), 1.);
s_i.data(), s_o.data(),
axes.back(), (const float *)in.data(),
(cmplx<float> *)res.mutable_data(), fct);
if (axes.size()>1)
pocketfft_general_c<float>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<float> *)res.data(),
(cmplx<float> *)res.mutable_data(), 1.);
}
return res;
}
py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize, const string &norm)
py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize,
double fct)
{
check_args(in, axes);
py::array inter;
......@@ -2316,14 +2276,14 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize, cons
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = ifftn(in, axes2, norm);
inter = ifftn(in, axes2, 1.);
}
else
inter = in;
size_t axis = axes.back();
if (lastsize==0) lastsize=2*(inter.shape(axis)/2);
if (lastsize/2 + 1 != inter.shape(axis))
if (lastsize==0) lastsize=1+(inter.shape(axis)/2);
if (int64_t(lastsize/2) + 1 != inter.shape(axis))
throw runtime_error("bad lastsize");
vector<size_t> dims_in(inter.ndim()), dims_out(inter.ndim());
vector<int64_t> s_i(inter.ndim()), s_o(inter.ndim());
......@@ -2333,11 +2293,6 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize, cons
dims_out[i] = inter.shape(i);
}
dims_out[axis] = lastsize;
double fct=1;
if (norm=="ortho")
fct=1./sqrt(lastsize);
else
fct = 1./lastsize;
py::array res;
if (inter.dtype().is(c128))
res = py::array_t<double>(dims_out);
......@@ -2355,17 +2310,17 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize, cons
}
if (inter.dtype().is(c128))
pocketfft_general_c2r<double>(inter.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, (const cmplx<double> *)inter.data(),
(double *)res.mutable_data(), fct);
s_i.data(), s_o.data(),
axis, (const cmplx<double> *)inter.data(),
(double *)res.mutable_data(), fct);
else
pocketfft_general_c2r<float>(inter.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, (const cmplx<float> *)inter.data(),
(float *)res.mutable_data(), fct);
s_i.data(), s_o.data(),
axis, (const cmplx<float> *)inter.data(),
(float *)res.mutable_data(), fct);
return res;
}
py::array hartley(const py::array &in, vector<size_t> axes)
py::array hartley(const py::array &in, vector<size_t> axes, double fct)
{
check_args(in, axes);
py::array res;
......@@ -2399,15 +2354,70 @@ py::array hartley(const py::array &in, vector<size_t> axes)
(float *)res.mutable_data(), 1.);
return res;
}
// Docstring shown in Python for `fftn`; it is passed to m.def() when the
// function is registered. The binding exposes the keyword arguments
// `a`, `axes` and `fct`, so all three are described here.
const char *fftn_DS = R"DELIM(
Performs a forward complex FFT.
Parameters
----------
a : numpy.ndarray (np.complex64 or np.complex128)
The input data
axes : list of integers
The axes on which the FFT is carried out.
If not set, all axes will be transformed.
fct : float
Normalization factor by which the result is multiplied (default: 1.0).
Returns
-------
np.ndarray (same shape and data type as a)
the transformed data
)DELIM";
// Docstring shown in Python for `ifftn`; it is passed to m.def() when the
// function is registered. The binding exposes the keyword arguments
// `a`, `axes` and `fct`, so all three are described here.
const char *ifftn_DS = R"DELIM(
Performs a backward complex FFT.
Parameters
----------
a : numpy.ndarray (np.complex64 or np.complex128)
The input data
axes : list of integers
The axes on which the FFT is carried out.
If not set, all axes will be transformed.
fct : float
Normalization factor by which the result is multiplied (default: 1.0).
Returns
-------
np.ndarray (same shape and data type as a)
the transformed data
)DELIM";
// Docstring shown in Python for `hartley`; it is passed to m.def() when the
// function is registered.
// NOTE(review): the binding also accepts a `fct` keyword that is not described
// in this text -- confirm how hartley() actually uses it before documenting it.
const char *hartley_DS = R"DELIM(
Performs a Hartley FFT.
For every requested axis, a 1D forward Fourier transform is carried out,
and the sum of real and imaginary parts of the result is stored in the output
array.
Parameters
----------
a : numpy.ndarray (np.float32 or np.float64)
The input data
axes : list of integers
The axes on which the transform is carried out.
If not set, all axes will be transformed.
Returns
-------
np.ndarray (same shape and data type as a)
the transformed data
)DELIM";
} // unnamed namespace
// Python module definition: registers the transform entry points under the
// module name `pypocketfft`.
PYBIND11_MODULE(pypocketfft, m)
{
// Brings the "name"_a keyword-argument literal into scope.
using namespace pybind11::literals;
// Every `axes` argument below defaults to a one-element vector holding 10000.
// NOTE(review): presumably a sentinel for "all axes" that check_args() expands -- confirm.
// Registrations taking a `norm` string (default: empty string); hartley takes neither `norm` nor `fct`.
m.def("fftn",&fftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("ifftn",&ifftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("rfftn",&rfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("irfftn",&irfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "lastsize"_a=0, "norm"_a=string(""));
m.def("hartley",&hartley, "a"_a, "axes"_a=vector<size_t>(1,10000));
// Registrations taking a numeric `fct` (default: 1.); fftn, ifftn and hartley
// additionally attach their *_DS docstring, rfftn and irfftn have none.
m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("rfftn",&rfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
m.def("irfftn",&irfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "lastsize"_a=0, "fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=vector<size_t>(1,10000), "fct"_a=1.);
}
......@@ -16,7 +16,6 @@ def test_fftn(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(pyfftw.interfaces.numpy_fft.fftn(a)))
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=2e-15*vmax, rtol=0)
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a,norm="ortho"), pypocketfft.fftn(a,norm="ortho"), atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=6e-7*vmax, rtol=0)
......@@ -34,7 +33,6 @@ def test_rfftn(shp):
a=np.random.rand(*shp)-0.5
vmax = np.max(np.abs(pyfftw.interfaces.numpy_fft.fftn(a)))
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a), pypocketfft.rfftn(a), atol=2e-15*vmax, rtol=0)
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a,norm="ortho"), pypocketfft.rfftn(a,norm="ortho"), atol=2e-15*vmax, rtol=0)
a=a.astype(np.float32)
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a), pypocketfft.rfftn(a), atol=6e-7*vmax, rtol=0)
......@@ -51,10 +49,9 @@ def test_rfftn2D(shp, axes):
def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(a))
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a)), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a,norm="ortho"), norm="ortho"), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a, atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a)), a, atol=6e-7*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a, atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes)
def test_identity_r(shp):
......@@ -63,9 +60,8 @@ def test_identity_r(shp):
vmax = np.max(np.abs(a))
for ax in range(a.ndim):
n = a.shape[ax]
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,(ax,)),(ax,),n), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,(ax,),norm="ortho"),(ax,),n,norm="ortho"), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(b,(ax,)),(ax,),n), b, atol=6e-7*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,(ax,)),(ax,),lastsize=n,fct=1./n), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(b,(ax,)),(ax,),lastsize=n,fct=1./n), b, atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes3D)
def xtest_hartley(shp):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment