Commit 84bcc100 authored by Martin Reinecke
Browse files

add support for inplace transforms

parent 8e275de8
...@@ -5,6 +5,8 @@ import pypocketfft ...@@ -5,6 +5,8 @@ import pypocketfft
from time import time from time import time
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def _l2error(a,b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""): def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""):
res=[] res=[]
...@@ -23,6 +25,8 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""): ...@@ -23,6 +25,8 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""):
b=pypocketfft.fftn(a) b=pypocketfft.fftn(a)
t1=time() t1=time()
tmin_pp = min(tmin_pp,t1-t0) tmin_pp = min(tmin_pp,t1-t0)
a2=pypocketfft.ifftn(b,fct=1./a.size)
assert(_l2error(a,a2)<(2e-15 if tp=='c16' else 6e-7))
res.append(tmin_pp/tmin_np) res.append(tmin_pp/tmin_np)
plt.title("t(pypocketfft / numpy 1.17), {}D, {}, max_extent={}".format(ndim, str(tp), nmax)) plt.title("t(pypocketfft / numpy 1.17), {}D, {}, max_extent={}".format(ndim, str(tp), nmax))
plt.xlabel("time ratio") plt.xlabel("time ratio")
......
...@@ -2060,27 +2060,43 @@ template<> struct VTYPE<float> ...@@ -2060,27 +2060,43 @@ template<> struct VTYPE<float>
#endif #endif
template<typename T> arr<char> alloc_tmp(const shape_t &shape, template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t tmpsize, size_t elemsize) size_t axsize, size_t elemsize)
{ {
#ifdef HAVE_VECSUPPORT #ifdef HAVE_VECSUPPORT
using vtype = typename VTYPE<T>::type; using vtype = typename VTYPE<T>::type;
constexpr int vlen = sizeof(vtype)/sizeof(T); constexpr int vlen = sizeof(vtype)/sizeof(T);
return arr<char>(tmpsize*elemsize*vlen); #else
constexpr int vlen = 1;
#endif #endif
size_t fullsize=1;
size_t ndim = shape.size();
for (size_t i=0; i<ndim; ++i)
fullsize*=shape[i];
auto othersize = fullsize/axsize;
auto tmpsize = axsize*((othersize>vlen) ? vlen : 1);
return arr<char>(tmpsize*elemsize); return arr<char>(tmpsize*elemsize);
} }
template<typename T> arr<char> alloc_tmp(const shape_t &shape, template<typename T> arr<char> alloc_tmp(const shape_t &shape,
const shape_t &axes, size_t elemsize) const shape_t &axes, size_t elemsize)
{ {
#ifdef HAVE_VECSUPPORT
using vtype = typename VTYPE<T>::type;
constexpr int vlen = sizeof(vtype)/sizeof(T);
#else
constexpr int vlen = 1;
#endif
size_t fullsize=1; size_t fullsize=1;
size_t ndim = shape.size(); size_t ndim = shape.size();
for (size_t i=0; i<ndim; ++i) for (size_t i=0; i<ndim; ++i)
fullsize*=shape[i]; fullsize*=shape[i];
size_t tmpsize=0; size_t tmpsize=0;
for (size_t i=0; i<axes.size(); ++i) for (size_t i=0; i<axes.size(); ++i)
tmpsize = max(tmpsize, shape[axes[i]]); {
auto axsize = shape[axes[i]];
return alloc_tmp<T>(shape, tmpsize, elemsize); auto othersize = fullsize/axsize;
tmpsize = max(tmpsize, axsize*((othersize>vlen) ? vlen : 1));
}
return arr<char>(tmpsize*elemsize);
} }
template<typename T> void pocketfft_general_c(const shape_t &shape, template<typename T> void pocketfft_general_c(const shape_t &shape,
...@@ -2413,26 +2429,27 @@ shape_t makeaxes(const py::array &in, py::object axes) ...@@ -2413,26 +2429,27 @@ shape_t makeaxes(const py::array &in, py::object axes)
} }
template<typename T> py::array xfftn_internal(const py::array &in, template<typename T> py::array xfftn_internal(const py::array &in,
const shape_t &axes, double fct, bool fwd) const shape_t &axes, double fct, bool inplace, bool fwd)
{ {
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims); py::array res = inplace ? in : py::array_t<complex<T>>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res)); auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_c<T>(dims, s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(), pocketfft_general_c<T>(dims, s_i, s_o, axes, fwd, (const cmplx<T> *)in.data(),
(cmplx<T> *)res.mutable_data(), fct); (cmplx<T> *)res.mutable_data(), fct);
return res; return res;
} }
py::array xfftn(const py::array &a, py::object axes, double fct, bool fwd) py::array xfftn(const py::array &a, py::object axes, double fct, bool inplace,
bool fwd)
{ {
return tcheck(a, c128, c64) ? return tcheck(a, c128, c64) ?
xfftn_internal<double>(a, makeaxes(a, axes), fct, fwd) : xfftn_internal<double>(a, makeaxes(a, axes), fct, inplace, fwd) :
xfftn_internal<float> (a, makeaxes(a, axes), fct, fwd); xfftn_internal<float> (a, makeaxes(a, axes), fct, inplace, fwd);
} }
py::array fftn(const py::array &a, py::object axes, double fct) py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, true); } { return xfftn(a, axes, fct, inplace, true); }
py::array ifftn(const py::array &a, py::object axes, double fct) py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
{ return xfftn(a, axes, fct, false); } { return xfftn(a, axes, fct, inplace, false); }
template<typename T> py::array rfftn_internal(const py::array &in, template<typename T> py::array rfftn_internal(const py::array &in,
py::object axes_, T fct) py::object axes_, T fct)
...@@ -2467,7 +2484,7 @@ template<typename T> py::array irfftn_internal(const py::array &in, ...@@ -2467,7 +2484,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
shape_t axes2(axes.size()-1); shape_t axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i) for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i]; axes2[i] = axes[i];
inter = xfftn_internal<T>(in, axes2, 1., false); inter = xfftn_internal<T>(in, axes2, 1., false, false);
} }
else else
inter = in; inter = in;
...@@ -2493,20 +2510,22 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize, ...@@ -2493,20 +2510,22 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
} }
template<typename T> py::array hartley_internal(const py::array &in, template<typename T> py::array hartley_internal(const py::array &in,
py::object axes_, double fct) py::object axes_, double fct, bool inplace)
{ {
auto axes(makeaxes(in, axes_)); auto axes(makeaxes(in, axes_));
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = py::array_t<T>(dims); py::array res = inplace ? in : py::array_t<T>(dims);
auto s_i(copy_strides(in)), s_o(copy_strides(res)); auto s_i(copy_strides(in)), s_o(copy_strides(res));
pocketfft_general_hartley<T>(dims, s_i, s_o, axes, (const T *)in.data(), pocketfft_general_hartley<T>(dims, s_i, s_o, axes, (const T *)in.data(),
(T *)res.mutable_data(), 1.); (T *)res.mutable_data(), 1.);
return res; return res;
} }
py::array hartley(const py::array &in, py::object axes_, double fct) py::array hartley(const py::array &in, py::object axes_, double fct,
bool inplace)
{ {
return tcheck(in, f64, f32) ? hartley_internal<double>(in, axes_, fct) return tcheck(in, f64, f32) ?
: hartley_internal<float> (in, axes_, fct); hartley_internal<double>(in, axes_, fct, inplace) :
hartley_internal<float> (in, axes_, fct, inplace);
} }
template<typename T>py::array complex2hartley(const py::array &in, template<typename T>py::array complex2hartley(const py::array &in,
...@@ -2589,6 +2608,9 @@ axes : list of integers ...@@ -2589,6 +2608,9 @@ axes : list of integers
If not set, all axes will be transformed. If not set, all axes will be transformed.
fct : float fct : float
Normalization factor Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns Returns
------- -------
...@@ -2607,6 +2629,9 @@ axes : list of integers ...@@ -2607,6 +2629,9 @@ axes : list of integers
If not set, all axes will be transformed. If not set, all axes will be transformed.
fct : float fct : float
Normalization factor Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns Returns
------- -------
...@@ -2670,6 +2695,9 @@ axes : list of integers ...@@ -2670,6 +2695,9 @@ axes : list of integers
If not set, all axes will be transformed. If not set, all axes will be transformed.
fct : float fct : float
Normalization factor Normalization factor
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
Returns Returns
------- -------
...@@ -2684,11 +2712,15 @@ PYBIND11_MODULE(pypocketfft, m) ...@@ -2684,11 +2712,15 @@ PYBIND11_MODULE(pypocketfft, m)
using namespace pybind11::literals; using namespace pybind11::literals;
m.doc() = pypocketfft_DS; m.doc() = pypocketfft_DS;
m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.); m.def("fftn",&fftn, fftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.); "inplace"_a=false);
m.def("ifftn",&ifftn, ifftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"inplace"_a=false);
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.); m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0, "fct"_a=1.); m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0,
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.); "fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.,
"inplace"_a=false);
m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "fct"_a=1.); m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("complex2hartley",&mycomplex2hartley, "in"_a, "tmp"_a, "axes"_a); m.def("complex2hartley",&mycomplex2hartley, "in"_a, "tmp"_a, "axes"_a);
} }
...@@ -20,8 +20,14 @@ def test1D(len): ...@@ -20,8 +20,14 @@ def test1D(len):
b=a.astype(np.complex64) b=a.astype(np.complex64)
assert(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./len))<1.5e-15) assert(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./len))<1.5e-15)
assert(_l2error(a.real, pypocketfft.irfftn(pypocketfft.rfftn(a.real),fct=1./len,lastsize=len))<1e-15) assert(_l2error(a.real, pypocketfft.irfftn(pypocketfft.rfftn(a.real),fct=1./len,lastsize=len))<1e-15)
tmp=a.copy()
assert (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), fct=1./len, inplace=True) is tmp)
assert(_l2error(tmp, a)<1.5e-15)
assert(_l2error(b, pypocketfft.ifftn(pypocketfft.fftn(b),fct=1./len))<5e-7) assert(_l2error(b, pypocketfft.ifftn(pypocketfft.fftn(b),fct=1./len))<5e-7)
assert(_l2error(b.real, pypocketfft.irfftn(pypocketfft.rfftn(b.real),fct=1./len,lastsize=len))<5e-7) assert(_l2error(b.real, pypocketfft.irfftn(pypocketfft.rfftn(b.real),fct=1./len,lastsize=len))<5e-7)
tmp=b.copy()
assert (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), fct=1./len, inplace=True) is tmp)
assert(_l2error(tmp, b)<5e-7)
@pmp("shp", shapes) @pmp("shp", shapes)
def test_fftn(shp): def test_fftn(shp):
...@@ -62,6 +68,9 @@ def test_identity(shp): ...@@ -62,6 +68,9 @@ def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(a)) vmax = np.max(np.abs(a))
assert(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a)<1e-15*vmax) assert(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a)<1e-15*vmax)
tmp=a.copy()
assert (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), fct=1./a.size, inplace=True) is tmp)
assert(_l2error(tmp, a)<1.5e-15*vmax)
a=a.astype(np.complex64) a=a.astype(np.complex64)
assert(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a)<5e-7*vmax) assert(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),fct=1./a.size), a)<5e-7*vmax)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment