Commit 89d506fd authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow real-valued arguments to ifftn()

parent d84ec6c3
......@@ -73,11 +73,13 @@ shape_t makeaxes(const py::array &in, py::object axes)
}
#define DISPATCH(arr, T1, T2, T3, func, args) \
{ \
auto dtype = arr.dtype(); \
if (dtype.is(T1)) return func<double> args; \
if (dtype.is(T2)) return func<float> args; \
if (dtype.is(T3)) return func<ldbl_t> args; \
throw runtime_error("unsupported data type");
throw runtime_error("unsupported data type"); \
}
template<typename T> T norm_fct(int inorm, size_t N)
{
......@@ -114,17 +116,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
return res;
}
py::array xfftn(const py::array &a, py::object axes, int inorm,
bool inplace, bool fwd, size_t nthreads)
{
DISPATCH(a, c128, c64, clong, xfftn_internal, (a, makeaxes(a, axes), inorm,
inplace, fwd, nthreads))
}
template<typename T> py::array sym_rfftn_internal(const py::array &in,
py::object axes_, int inorm, size_t nthreads)
const shape_t &axes, int inorm, bool fwd, size_t nthreads)
{
auto axes = makeaxes(in, axes_);
auto dims(copy_shape(in));
py::array res = py::array_t<complex<T>>(dims);
auto s_in=copy_strides(in);
......@@ -145,20 +139,36 @@ template<typename T> py::array sym_rfftn_internal(const py::array &in,
ares[iter.rev_ofs()] = conj(v);
iter.advance();
}
if (!fwd)
{
simple_iter iter(ares);
while(iter.remaining()>0)
{
auto v = ares[iter.ofs()];
ares[iter.ofs()] = conj(v);
iter.advance();
}
}
}
return res;
}
py::array fftn(const py::array &a, py::object axes, int inorm, bool inplace,
size_t nthreads)
py::array xfftn(const py::array &a, py::object axes, int inorm,
bool inplace, bool fwd, size_t nthreads)
{
if (a.dtype().kind() == 'c')
return xfftn(a, axes, inorm, inplace, true, nthreads);
if (a.dtype().kind() == 'c')
DISPATCH(a, c128, c64, clong, xfftn_internal, (a, makeaxes(a, axes), inorm,
inplace, fwd, nthreads))
if (inplace) throw runtime_error("cannot do this operation in-place");
DISPATCH(a, f64, f32, flong, sym_rfftn_internal, (a, axes, inorm, nthreads))
if (inplace) throw runtime_error("cannot do this operation in-place");
DISPATCH(a, f64, f32, flong, sym_rfftn_internal, (a, makeaxes(a, axes), inorm,
fwd, nthreads))
}
py::array fftn(const py::array &a, py::object axes, int inorm, bool inplace,
size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, true, nthreads); }
py::array ifftn(const py::array &a, py::object axes, int inorm,
bool inplace, size_t nthreads)
{ return xfftn(a, axes, inorm, inplace, false, nthreads); }
......@@ -359,8 +369,9 @@ const char *ifftn_DS = R"""(Performs a backward complex FFT.
Parameters
----------
a : numpy.ndarray (any complex type)
The input data
a : numpy.ndarray (any complex or real type)
The input data. If its type is real, a more efficient real-to-complex
transform will be used.
axes : list of integers
The axes along which the FFT is carried out.
If not set, all axes will be transformed.
......@@ -373,13 +384,15 @@ inorm : int
inplace : bool
if False, returns the result in a new array and leaves the input unchanged.
if True, stores the result in the input array and returns a handle to it.
NOTE: if `a` is real-valued and `inplace` is `True`, an exception will be
raised!
nthreads : int
Number of threads to use. If 0, use the system default (typically governed
by the `OMP_NUM_THREADS` environment variable).
Returns
-------
np.ndarray (same shape and data type as a)
np.ndarray (same shape as `a`, complex type with same accuracy as `a`)
The transformed data
)""";
......
......@@ -29,6 +29,11 @@ def test():
print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.ifftn(pypocketfft.fftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b)
if err > cmaxerr:
cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.fftn(pypocketfft.ifftn(a.real,axes=axes,nthreads=nthreads),axes=axes,inorm=2,nthreads=nthreads)
err = _l2error(a.real,b)
if err > cmaxerr:
cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes)
......
......@@ -24,12 +24,14 @@ def test1D(len, inorm):
assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(c,inorm=inorm), inorm=2-inorm))<1e-18)
assert_(_l2error(a, pypocketfft.ifftn(pypocketfft.fftn(a,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.ifftn(pypocketfft.fftn(a.real,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.fftn(pypocketfft.ifftn(a.real,inorm=inorm), inorm=2-inorm))<1.5e-15)
assert_(_l2error(a.real, pypocketfft.irfftn(pypocketfft.rfftn(a.real,inorm=inorm), inorm=2-inorm,lastsize=len))<1.5e-15)
tmp=a.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp)
assert_(_l2error(tmp, a)<1.5e-15)
assert_(_l2error(b, pypocketfft.ifftn(pypocketfft.fftn(b, inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.ifftn(pypocketfft.fftn(b.real,inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.fftn(pypocketfft.ifftn(b.real,inorm=inorm), inorm=2-inorm))<6e-7)
assert_(_l2error(b.real, pypocketfft.irfftn(pypocketfft.rfftn(b.real, inorm=inorm), lastsize=len, inorm=2-inorm))<6e-7)
tmp=b.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True, inorm=inorm), inplace=True, inorm=2-inorm) is tmp)
......@@ -77,6 +79,8 @@ def test_rfftn2D(shp, axes):
def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
assert_(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a),inorm=2), a)<1.5e-15)
assert_(_l2error(pypocketfft.ifftn(pypocketfft.fftn(a.real),inorm=2), a.real)<1.5e-15)
assert_(_l2error(pypocketfft.fftn(pypocketfft.ifftn(a.real),inorm=2), a.real)<1.5e-15)
tmp=a.copy()
assert_ (pypocketfft.ifftn(pypocketfft.fftn(tmp, inplace=True), inorm=2, inplace=True) is tmp)
assert_(_l2error(tmp, a)<1.5e-15)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment