Commit 89d506fd authored by Martin Reinecke's avatar Martin Reinecke

allow real-valued arguments to ifftn()

parent d84ec6c3
...@@ -73,11 +73,13 @@ shape_t makeaxes(const py::array &in, py::object axes) ...@@ -73,11 +73,13 @@ shape_t makeaxes(const py::array &in, py::object axes)
} }
#define DISPATCH(arr, T1, T2, T3, func, args) \ #define DISPATCH(arr, T1, T2, T3, func, args) \
{ \
auto dtype = arr.dtype(); \ auto dtype = arr.dtype(); \
if (dtype.is(T1)) return func<double> args; \ if (dtype.is(T1)) return func<double> args; \
if (dtype.is(T2)) return func<float> args; \ if (dtype.is(T2)) return func<float> args; \
if (dtype.is(T3)) return func<ldbl_t> args; \ if (dtype.is(T3)) return func<ldbl_t> args; \
throw runtime_error("unsupported data type"); throw runtime_error("unsupported data type"); \
}
template<typename T> T norm_fct(int inorm, size_t N) template<typename T> T norm_fct(int inorm, size_t N)
{ {
...@@ -114,17 +116,9 @@ template<typename T> py::array xfftn_internal(const py::array &in, ...@@ -114,17 +116,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
return res; return res;
} }
py::array xfftn(const py::array &a, py::object axes, int inorm,
bool inplace, bool fwd, size_t nthreads)
{
DISPATCH(a, c128, c64, clong, xfftn_internal, (a, makeaxes(a, axes), inorm,
inplace, fwd, nthreads))
}
template<typename T> py::array sym_rfftn_internal(const py::array &in, template<typename T> py::array sym_rfftn_internal(const py::array &in,
py::object axes_, int inorm, size_t nthreads) const shape_t &axes, int inorm, bool fwd, size_t nthreads)
{ {
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in)); auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims); py::array res = py::array_t<complex<T>>(dims);
auto s_in=copy_strides(in); auto s_in=copy_strides(in);
...@@ -145,20 +139,36 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in, ...@@ -145,20 +139,36 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
ares[iter.rev_ofs()] = conj(v); ares[iter.rev_ofs()] = conj(v);
iter.advance(); iter.advance();
} }
if (!fwd)
{
simple_iter iter(ares);
while(iter.remaining()>0)
{
auto v = ares[iter.ofs()];
ares[iter.ofs()] = conj(v);
iter.advance();
}
}
} }
return res; return res;
} }
py::array fftn(const py::array &a, py::object axes, int inorm, bool inplace, py::array xfftn(const py::array &a, py::object axes, int inorm,
size_t nthreads) bool inplace, bool fwd, size_t nthreads)
{ {
if (a.dtype().kind() == 'c') if (a.dtype().kind() == 'c')
return xfftn(a, axes, inorm, inplace, true, nthreads); DISPATCH(a, c128, c64, clong, xfftn_internal, (a, makeaxes(a, axes), inorm,
inplace, fwd, nthreads))
if (inplace) throw runtime_error("cannot do this operation in-place"); if (inplace) throw runtime_error("cannot do this operation in-place");
DISPATCH(a, f64, f32, flong, sym_rfftn_internal, (a, axes, inorm, nthreads)) DISPATCH(a, f64, f32, flong, sym_rfftn_internal, (a, makeaxes(a, axes), inorm,
fwd, nthreads))
} }
py::array fftn(const py::array &a, py::object axes, int inorm, bool inplace,
size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, true, nthreads); }
py::array ifftn(const py::array &a, py::object axes, int inorm, py::array ifftn(const py::array &a, py::object axes, int inorm,
bool inplace, size_t nthreads) bool inplace, size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, false, nthreads); } { return xfftn(a, axes, inorm, inplace, false, nthreads); }
...@@ -359,8 +369,9 @@ const char *ifftn_DS = R"""(Performs a backward complex FFT. ...@@ -359,8 +369,9 @@ const char *ifftn_DS = R"""(Performs a backward complex FFT.
Parameters Parameters
---------- ----------
a : numpy.ndarray (any complex type) a : numpy.ndarray (any complex or real type)
The input data The input data. If its type is real, a more efficient real-to-complex
transform will be used.
axes : list of integers axes : list of integers
The axes along which the FFT is carried out. The axes along which the FFT is carried out.
If not set, all axes will be transformed. If not set, all axes will be transformed.
...@@ -373,13 +384,15 @@ inorm : int ...@@ -373,13 +384,15 @@ inorm : int
inplace : bool inplace : bool
if False, returns the result in a new array and leaves the input unchanged. if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it. if True, stores the result in the input array and returns a handle to it.
NOTE: if `a` is real-valued and `inplace` is `True`, an exception will be
raised!
nthreads : int nthreads : int
Number of threads to use. If 0, use the system default (typically governed Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable). by the `OMP_NUM_THREADS` environment variable).
Returns Returns
------- -------
np.ndarray (same shape and data type as a) np.ndarray (same shape as `a`, complex type with same accuracy as `a`)
The transformed data The transformed data
)"""; )""";
......
...@@ -29,6 +29,11 @@ def test(): ...@@ -29,6 +29,11 @@ def test():
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.ifftn(pypocketfft.fftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads) b=pypocketfft.ifftn(pypocketfft.fftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > cmaxerr:
cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.fftn(pypocketfft.ifftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b)
if err > cmaxerr: if err > cmaxerr:
cmaxerr = err cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
......
...@@ -24,12 +24,14 @@ def test1D(len, inorm): ...@@ -24,12 +24,14 @@ def test1D(len, inorm):
assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(c,inorm=inorm), inorm=2-inorm))<1e-18) assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(c,inorm=inorm), inorm=2-inorm))<1e-18)
assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(a,inorm=inorm), inorm=2-inorm))<1.5e-15) assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(a,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.ifftn(pypocketfft.fftn(a.real,inorm=inorm), inorm=2-inorm))<1.5e-15) assert_(_l2error(a.real, pypocketfft.ifftn(pypocketfft.fftn(a.real,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.fftn(pypocketfft.ifftn(a.real,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.irfftn(pypocketfft.rfftn(a.real,inorm=inorm), inorm=2-inorm,lastsize=len))<1.5e-15) assert_(_l2error(a.real, pypocketfft.irfftn(pypocketfft.rfftn(a.real,inorm=inorm), inorm=2-inorm,lastsize=len))<1.5e-15)
tmp=a.copy() tmp=a.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp) assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp)
assert_(_l2error(tmp, a)<1.5e-15) assert_(_l2error(tmp, a)<1.5e-15)
assert_(_l2error(b, pypocketfft.ifftn(pypocketfft.fftn(b, inorm=inorm), inorm=2-inorm))<6e-7) assert_(_l2error(b, pypocketfft.ifftn(pypocketfft.fftn(b, inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.ifftn(pypocketfft.fftn(b.real,inorm=inorm), inorm=2-inorm))<6e-7) assert_(_l2error(b.real, pypocketfft.ifftn(pypocketfft.fftn(b.real,inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.fftn(pypocketfft.ifftn(b.real,inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.irfftn(pypocketfft.rfftn(b.real, inorm=inorm), lastsize=len, inorm=2-inorm))<6e-7) assert_(_l2error(b.real, pypocketfft.irfftn(pypocketfft.rfftn(b.real, inorm=inorm), lastsize=len, inorm=2-inorm))<6e-7)
tmp=b.copy() tmp=b.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp) assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp)
...@@ -77,6 +79,8 @@ def test_rfftn2D(shp, axes): ...@@ -77,6 +79,8 @@ def test_rfftn2D(shp, axes):
def test_identity(shp): def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
assert_(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),inorm=2), a)<1.5e-15) assert_(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),inorm=2), a)<1.5e-15)
assert_(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a.real),inorm=2), a.real)<1.5e-15)
assert_(_l2error(pypocketfft.fftn(pypocketfft.ifftn(a.real),inorm=2), a.real)<1.5e-15)
tmp=a.copy() tmp=a.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), inorm=2, inplace=True) is tmp) assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), inorm=2, inplace=True) is tmp)
assert_(_l2error(tmp, a)<1.5e-15) assert_(_l2error(tmp, a)<1.5e-15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment