Commit cef9199e authored by Martin Reinecke
Browse files

Merge branch 'develop' into 'master'

Develop

See merge request !3
parents 1307970b 39de2274
This diff is collapsed.
/*
* This file is part of pocketfft.
* Licensed under a 3-clause BSD style license - see LICENSE.md
*/
/*
* Python interface.
*
* Copyright (C) 2019 Max-Planck-Society
* \author Martin Reinecke
*/
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#pragma GCC visibility push(hidden)
#include "pocketfft_hdronly.h"
//
// Python interface
//
namespace {
using namespace std;
using namespace pocketfft;
using namespace pocketfft::detail;
namespace py = pybind11;
// NumPy dtype handles for the four supported element types. They are passed
// to tcheck() by the dispatch functions below to choose between the single-
// and double-precision template instantiations.
auto c64 = py::dtype("complex64");
auto c128 = py::dtype("complex128");
auto f32 = py::dtype("float32");
auto f64 = py::dtype("float64");
// Tells which of two candidate dtypes an array holds.
// Returns true when the dtype of arr is t1 and false when it is t2; any other
// dtype raises runtime_error. The comparison is by Python object identity.
bool tcheck(const py::array &arr, const py::object &t1, const py::object &t2)
  {
  const bool is_first = arr.dtype().is(t1);
  if (!is_first && !arr.dtype().is(t2))
    throw runtime_error("unsupported data type");
  return is_first;
  }
// Returns the extents of all dimensions of arr as a shape_t.
shape_t copy_shape(const py::array &arr)
  {
  shape_t dims(arr.ndim());
  size_t ax = 0;
  for (auto &extent: dims)
    extent = arr.shape(ax++);
  return dims;
  }
// Returns the strides of all dimensions of arr, exactly as reported by
// py::array::strides(), as a stride_t.
stride_t copy_strides(const py::array &arr)
  {
  stride_t steps(arr.ndim());
  size_t ax = 0;
  for (auto &step: steps)
    step = arr.strides(ax++);
  return steps;
  }
// Translates the Python-side `axes` argument into a list of axis indices.
//   in   : the array the axes refer to (only its number of dimensions is used)
//   axes : None (meaning all axes, in ascending order) or a sequence of
//          integers. Negative entries count from the last axis, as in NumPy.
// Returns the axis indices in the order given, all in [0, in.ndim()).
// Throws runtime_error if the sequence is empty, has more entries than the
// array has dimensions, or contains an out-of-range axis.
shape_t makeaxes(const py::array &in, py::object axes)
  {
  const ptrdiff_t ndim = in.ndim();
  if (axes.is(py::none()))
    {
    shape_t res(ndim);
    for (size_t i=0; i<res.size(); ++i)
      res[i]=i;
    return res;
    }
  // Cast to a signed type so that negative axes reach the range check below
  // instead of failing inside the pybind11 conversion.
  auto tmp=axes.cast<vector<ptrdiff_t>>();
  if ((tmp.size()>size_t(ndim)) || (tmp.size()==0))
    throw runtime_error("bad axes argument");
  shape_t res;
  res.reserve(tmp.size());
  for (auto ax: tmp)
    {
    if (ax<0)
      ax += ndim; // NumPy convention: -1 is the last axis
    if ((ax<0) || (ax>=ndim))
      throw runtime_error("invalid axis number");
    res.push_back(size_t(ax));
    }
  return res;
  }
// Complex-to-complex FFT over the given axes for element type complex<T>.
//   in      : input array
//   axes    : axes to transform (already validated indices)
//   fct     : normalization factor handed to general_c
//   inplace : if true the result is written into `in`, otherwise into a
//             freshly allocated array of the same shape
//   fwd     : direction flag handed to general_c
// Returns the array holding the result.
template<typename T> py::array xfftn_internal(const py::array &in,
  const shape_t &axes, double fct, bool inplace, bool fwd)
  {
  auto shp = copy_shape(in);
  py::array out = inplace ? in : py::array_t<complex<T>>(shp);
  ndarr<cmplx<T>> src(in.data(), shp, copy_strides(in));
  ndarr<cmplx<T>> dst(out.mutable_data(), shp, copy_strides(out));
  general_c<T>(src, dst, axes, fwd, fct);
  return out;
  }
// Precision dispatcher for the complex FFT: complex128 input goes to the
// double instantiation of xfftn_internal, complex64 to the float one; any
// other dtype makes tcheck() throw. The axes argument is resolved by
// makeaxes() after the dtype check.
py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
  bool fwd)
  {
  const bool is_double = tcheck(a, c128, c64);
  const shape_t ax = makeaxes(a, axes);
  if (is_double)
    return xfftn_internal<double>(a, ax, fct, inplace, fwd);
  return xfftn_internal<float>(a, ax, fct, inplace, fwd);
  }
// Forward complex FFT: xfftn() with the direction flag set to true.
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
  {
  const bool forward = true;
  return xfftn(a, axes, fct, inplace, forward);
  }
// Backward complex FFT: xfftn() with the direction flag set to false.
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
  {
  const bool forward = false;
  return xfftn(a, axes, fct, inplace, forward);
  }
// Forward real-to-complex FFT over the requested axes for element type T.
// The last entry of the axes list is transformed with general_r2c, which
// shrinks that axis from n to n/2+1 entries; all remaining axes are then
// transformed in place on the complex result with general_c.
//   in    : real-valued input array
//   axes_ : axes specification, resolved by makeaxes()
//   fct   : normalization factor; applied in the r2c step only, the
//           subsequent complex step uses a factor of 1
// Returns a newly allocated complex array.
template<typename T> py::array rfftn_internal(const py::array &in,
  py::object axes_, T fct)
  {
  auto axes = makeaxes(in, axes_);
  auto dims_in(copy_shape(in)), dims_out(dims_in);
  dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
  py::array res = py::array_t<complex<T>>(dims_out);
  ndarr<T> ain(in.data(), dims_in, copy_strides(in));
  ndarr<cmplx<T>> aout(res.mutable_data(), dims_out, copy_strides(res));
  general_r2c<T>(ain, aout, axes.back(), fct);
  if (axes.size()==1) return res;
  // axes.end()-1 instead of --axes.end(): decrementing a temporary iterator
  // is ill-formed when vector::iterator is a plain pointer.
  shape_t axes2(axes.begin(), axes.end()-1);
  general_c<T>(aout, aout, axes2, true, 1.);
  return res;
  }
// Precision dispatcher for the forward real FFT: float64 input uses the
// double instantiation of rfftn_internal, float32 the float one; any other
// dtype makes tcheck() throw.
py::array rfftn(const py::array &in, py::object axes_, double fct)
  {
  if (tcheck(in, f64, f32))
    return rfftn_internal<double>(in, axes_, fct);
  return rfftn_internal<float>(in, axes_, fct);
  }
// Real-valued FFT along a single axis with real-valued output of the same
// shape (the layout the rfft_scipy/irfft_scipy docstrings describe as FFTPACK
// half-complex order).
//   in      : real-valued input array
//   axis    : axis to transform; must be smaller than in.ndim()
//   fct     : normalization factor handed to general_r
//   inplace : if true the result is written into `in`, otherwise into a
//             freshly allocated array of the same shape
//   fwd     : direction flag handed to general_r
// Throws runtime_error if axis is out of range.
template<typename T> py::array xrfft_scipy(const py::array &in,
  size_t axis, double fct, bool inplace, bool fwd)
  {
  // The multi-axis entry points validate their axes in makeaxes(); this
  // single-axis path bypasses it, so the same check is done here.
  if (axis>=size_t(in.ndim()))
    throw runtime_error("invalid axis number");
  auto dims(copy_shape(in));
  py::array res = inplace ? in : py::array_t<T>(dims);
  ndarr<T> ain(in.data(), dims, copy_strides(in));
  ndarr<T> aout(res.mutable_data(), dims, copy_strides(res));
  general_r<T>(ain, aout, axis, fwd, fct);
  return res;
  }
// Forward single-axis real FFT: dispatches on float64/float32 and calls
// xrfft_scipy with the direction flag set to true.
py::array rfft_scipy(const py::array &in, size_t axis, double fct, bool inplace)
  {
  const bool forward = true;
  if (tcheck(in, f64, f32))
    return xrfft_scipy<double>(in, axis, fct, inplace, forward);
  return xrfft_scipy<float>(in, axis, fct, inplace, forward);
  }
// Backward single-axis real FFT: dispatches on float64/float32 and calls
// xrfft_scipy with the direction flag set to false.
py::array irfft_scipy(const py::array &in, size_t axis, double fct,
  bool inplace)
  {
  const bool forward = false;
  if (tcheck(in, f64, f32))
    return xrfft_scipy<double>(in, axis, fct, inplace, forward);
  return xrfft_scipy<float>(in, axis, fct, inplace, forward);
  }
// Backward complex-to-real FFT over the requested axes for element type T.
// All axes except the last entry of the axes list are first transformed with
// a backward complex FFT (into a temporary array, factor 1); the last axis is
// then transformed with general_c2r, which applies fct.
//   in       : complex input array
//   axes_    : axes specification, resolved by makeaxes()
//   lastsize : length of the last transformed axis in the output; 0 selects
//              2*n-1, where n is the input length of that axis
//   fct      : normalization factor
// Returns a newly allocated real array.
// Throws runtime_error("bad lastsize") unless lastsize/2+1 equals the input
// length of the last transformed axis.
template<typename T> py::array irfftn_internal(const py::array &in,
  py::object axes_, size_t lastsize, T fct)
  {
  auto axes = makeaxes(in, axes_);
  // axes.end()-1 instead of --axes.end(): decrementing a temporary iterator
  // is ill-formed when vector::iterator is a plain pointer.
  py::array inter = (axes.size()==1) ? in :
    xfftn_internal<T>(in, shape_t(axes.begin(), axes.end()-1), 1., false, false);
  size_t axis = axes.back();
  if (lastsize==0) lastsize=2*inter.shape(axis)-1;
  if (ptrdiff_t(lastsize/2) + 1 != inter.shape(axis))
    throw runtime_error("bad lastsize");
  auto dims_out(copy_shape(inter));
  dims_out[axis] = lastsize;
  py::array res = py::array_t<T>(dims_out);
  ndarr<cmplx<T>> ain(inter.data(), copy_shape(inter), copy_strides(inter));
  ndarr<T> aout(res.mutable_data(), dims_out, copy_strides(res));
  general_c2r<T>(ain, aout, axis, fct);
  return res;
  }
// Precision dispatcher for the backward real FFT: complex128 input uses the
// double instantiation of irfftn_internal, complex64 the float one; any other
// dtype makes tcheck() throw.
py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
  double fct)
  {
  if (tcheck(in, c128, c64))
    return irfftn_internal<double>(in, axes_, lastsize, fct);
  return irfftn_internal<float>(in, axes_, lastsize, fct);
  }
// Hartley transform over the requested axes for element type T.
//   in      : real-valued input array
//   axes_   : axes specification, resolved by makeaxes()
//   fct     : normalization factor handed to general_hartley
//   inplace : if true the result is written into `in`, otherwise into a
//             freshly allocated array of the same shape
// Returns the array holding the result.
template<typename T> py::array hartley_internal(const py::array &in,
  py::object axes_, double fct, bool inplace)
  {
  auto shp = copy_shape(in);
  py::array out = inplace ? in : py::array_t<T>(shp);
  ndarr<T> src(in.data(), copy_shape(in), copy_strides(in));
  ndarr<T> dst(out.mutable_data(), copy_shape(out), copy_strides(out));
  general_hartley<T>(src, dst, makeaxes(in, axes_), fct);
  return out;
  }
// Precision dispatcher for the Hartley transform: float64 input uses the
// double instantiation of hartley_internal, float32 the float one; any other
// dtype makes tcheck() throw.
py::array hartley(const py::array &in, py::object axes_, double fct,
  bool inplace)
  {
  if (tcheck(in, f64, f32))
    return hartley_internal<double>(in, axes_, fct, inplace);
  return hartley_internal<float>(in, axes_, fct, inplace);
  }
// Combines a complex array into a real-valued array with the shape of `in`:
// for every complex entry (re, im) of tmp, re+im is stored at the entry's own
// position and re-im at the index-mirrored position of the output.
//   in      : supplies the output shape and the axes count; its data is not
//             read here. If inplace is true, the output is written into it.
//   tmp     : complex input (presumably the result of rfftn(in, axes_), as in
//             hartley2 -- confirm for other callers)
//   axes_   : axes specification, resolved by makeaxes() against `in`
//   inplace : if false, a new array is allocated for the output
// Returns the output array.
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace)
{
int ndim = in.ndim();
auto dims_out(copy_shape(in));
py::array out = inplace ? in : py::array_t<T>(dims_out);
ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
ndarr<T> aout(out.mutable_data(), copy_shape(out), copy_strides(out));
auto axes = makeaxes(in, axes_);
// Lines are iterated along the last entry of the axes list.
size_t axis = axes.back();
multi_iter<1,cmplx<T>,T> it(atmp, aout, axis);
// swp[i] is true for every requested axis other than `axis`; positions along
// those axes are mirrored when locating the second output entry.
vector<bool> swp(ndim,false);
for (auto i: axes)
if (i!=axis)
swp[i] = true;
while(it.remaining()>0)
{
// rofs: offset of the start of the mirrored output line, accumulated from
// the iterator position and the output strides. It is computed before
// it.advance(1) (presumably because advance() moves it.pos on -- confirm in
// pocketfft_hdronly.h).
ptrdiff_t rofs = 0;
for (size_t i=0; i<it.pos.size(); ++i)
{
if (i==axis) continue;
if (!swp[i])
rofs += it.pos[i]*it.oarr.stride(i);
else
{
// Mirror the position along a transformed axis; index 0 maps to itself.
auto x = (it.pos[i]==0) ? 0 : it.iarr.shape(i)-it.pos[i];
rofs += x*it.oarr.stride(i);
}
}
it.advance(1);
for (size_t i=0; i<it.length_in(); ++i)
{
auto re = it.in(i).r;
auto im = it.in(i).i;
// Mirrored index along the iterated axis; index 0 maps to itself.
auto rev_i = (i==0) ? 0 : it.length_out()-i;
it.out(i) = re+im;
aout[rofs + rev_i*it.stride_out()] = re-im;
}
}
return out;
}
// Precision dispatcher for complex2hartley, keyed on the dtype of `in`:
// float64 selects the double instantiation, float32 the float one; any other
// dtype makes tcheck() throw.
py::array mycomplex2hartley(const py::array &in,
  const py::array &tmp, py::object axes_, bool inplace)
  {
  if (tcheck(in, f64, f32))
    return complex2hartley<double>(in, tmp, axes_, inplace);
  return complex2hartley<float>(in, tmp, axes_, inplace);
  }
// Hartley transform built from the real FFT: runs rfftn() on the input and
// hands the complex result to mycomplex2hartley().
py::array hartley2(const py::array &in, py::object axes_, double fct,
  bool inplace)
  {
  const py::array spectrum = rfftn(in, axes_, fct);
  return mycomplex2hartley(in, spectrum, axes_, inplace);
  }
// Module-level docstring; assigned to m.doc() in PYBIND11_MODULE.
const char *pypocketfft_DS = R"DELIM(Fast Fourier and Hartley transforms.
This module supports
- single and double precision
- complex and real-valued transforms
- multi-dimensional transforms
For two- and higher-dimensional transforms the code will use SSE2 and AVX
vector instructions for faster execution if these are supported by the CPU and
were enabled during compilation.
)DELIM";
// Python docstring of fftn(). The text starts directly after the raw-string
// delimiter, like every other docstring in this file, so that no stray empty
// line precedes the summary.
const char *fftn_DS = R"DELIM(Performs a forward complex FFT.
Parameters
----------
a : numpy.ndarray (np.complex64 or np.complex128)
The input data
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed.
fct : float
Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns
-------
np.ndarray (same shape and data type as a)
The transformed data.
)DELIM";
// Python docstring of ifftn().
const char *ifftn_DS = R"DELIM(Performs a backward complex FFT.
Parameters
----------
a : numpy.ndarray (np.complex64 or np.complex128)
The input data
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed.
fct : float
Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns
-------
np.ndarray (same shape and data type as a)
The transformed data
)DELIM";
// Python docstring of rfftn().
const char *rfftn_DS = R"DELIM(Performs a forward real-valued FFT.
Parameters
----------
a : numpy.ndarray (np.float32 or np.float64)
The input data
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed in ascending order.
fct : float
Normalization factor
Returns
-------
np.ndarray (np.complex64 or np.complex128)
The transformed data. The shape is identical to that of the input array,
except for the axis that was transformed last. If the length of that axis
was n on input, it is n//2+1 on output.
)DELIM";
// Python docstring of rfft_scipy().
const char *rfft_scipy_DS = R"DELIM(Performs a forward real-valued FFT.
Parameters
----------
a : numpy.ndarray (np.float32 or np.float64)
The input data
axis : int
The axis along which the FFT is carried out.
fct : float
Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns
-------
np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array.
Along the transformed axis, values are arranged in
FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`.
)DELIM";
// Python docstring of irfftn(). Documents the type of `lastsize` and the
// meaning of its default value 0 (see irfftn_internal).
const char *irfftn_DS = R"DELIM(Performs a backward real-valued FFT.
Parameters
----------
a : numpy.ndarray (np.complex64 or np.complex128)
The input data
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed in ascending order.
lastsize : int
The output size of the last axis to be transformed.
If the corresponding input axis has size n, this can be 2*n-2 or 2*n-1.
If 0 (the default), 2*n-1 is used.
fct : float
Normalization factor
Returns
-------
np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array,
except for the axis that was transformed last, which has now `lastsize`
entries.
)DELIM";
// Python docstring of irfft_scipy().
const char *irfft_scipy_DS = R"DELIM(Performs a backward real-valued FFT.
Parameters
----------
a : numpy.ndarray (np.float32 or np.float64)
The input data. Along the transformed axis, values are expected in
FFTPACK half-complex order, i.e. `a[0].re, a[1].re, a[1].im, a[2].re ...`.
axis : int
The axis along which the FFT is carried out.
fct : float
Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns
-------
np.ndarray (np.float32 or np.float64)
The transformed data. The shape is identical to that of the input array.
)DELIM";
// Python docstring of hartley().
const char *hartley_DS = R"DELIM(Performs a Hartley transform.
For every requested axis, a 1D forward Fourier transform is carried out,
and the sum of real and imaginary parts of the result is stored in the output
array.
Parameters
----------
a : numpy.ndarray (np.float32 or np.float64)
The input data
axes : list of integers
The axes along which the transform is carried out.
If not set, all axes will be transformed.
fct : float
Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns
-------
np.ndarray (same shape and data type as a)
The transformed data
)DELIM";
} // unnamed namespace
#pragma GCC visibility pop
// Module definition: registers every transform under its Python name together
// with its docstring, keyword-argument names and default values.
PYBIND11_MODULE(pypocketfft, m)
  {
  using namespace pybind11::literals;
  m.doc() = pypocketfft_DS;
  m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
    "inplace"_a=false);
  m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
    "inplace"_a=false);
  m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
  m.def("rfft_scipy",&rfft_scipy, rfft_scipy_DS, "a"_a, "axis"_a, "fct"_a=1.,
    "inplace"_a=false);
  m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0,
    "fct"_a=1.);
  m.def("irfft_scipy",&irfft_scipy, irfft_scipy_DS, "a"_a, "axis"_a, "fct"_a=1.,
    "inplace"_a=false);
  m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
    "inplace"_a=false);
  // The two functions below previously had no docstring; short inline ones
  // are supplied so that help() shows more than the bare signature.
  m.def("hartley2",&hartley2,
    "Performs a Hartley transform by running rfftn on the input and "
    "converting the result with complex2hartley.",
    "a"_a, "axes"_a=py::none(), "fct"_a=1., "inplace"_a=false);
  // `axes` defaults to None (all axes) like in every other multi-axis function.
  m.def("complex2hartley",&mycomplex2hartley,
    "Combines the complex array `tmp` (e.g. the output of rfftn for `in`) "
    "into a real-valued array with the shape of `in`.",
    "in"_a, "tmp"_a, "axes"_a=py::none(), "inplace"_a=false);
  }
......@@ -32,14 +32,6 @@ class _deferred_pybind11_include(object):
return pybind11.get_include(self.user)
def _get_cc_files(directory):
path = directory
iterable_sources = (iglob(os.path.join(root, '*.cc'))
for root, dirs, files in os.walk(path))
source_files = itertools.chain.from_iterable(iterable_sources)
return list(source_files)
def _remove_strict_prototype_option_from_distutils_config():
strict_prototypes = '-Wstrict-prototypes'
config = distutils.sysconfig.get_config_vars()
......@@ -61,7 +53,7 @@ python_module_link_args = []
base_library_link_args = []
if sys.platform == 'darwin':
extra_cc_compile_args.append('--std=c++14')
extra_cc_compile_args.append('--std=c++11')
extra_cc_compile_args.append('--stdlib=libc++')
extra_compile_args.append('-mmacosx-version-min=10.9')
......@@ -73,7 +65,7 @@ if sys.platform == 'darwin':
else:
extra_compile_args += ['-march=native', '-O3', '-Wfatal-errors', '-Wno-ignored-attributes']
python_module_link_args += ['-march=native']
extra_cc_compile_args.append('--std=c++14')
extra_cc_compile_args.append('--std=c++11')
python_module_link_args.append("-Wl,-rpath,$ORIGIN")
extra_cc_compile_args = extra_compile_args + extra_cc_compile_args
......@@ -82,7 +74,7 @@ extra_cc_compile_args = extra_compile_args + extra_cc_compile_args
def get_extension_modules():
extension_modules = []
pocketfft_sources = ['pocketfft.cc']
pocketfft_sources = ['pypocketfft.cc']
pocketfft_library = Extension('pypocketfft',
sources=pocketfft_sources,
include_dirs=include_dirs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment