Commit 06d75f2c authored by Martin Reinecke
Browse files

cosmetics

parent df6b820b
......@@ -358,8 +358,8 @@ template<typename T0> class cfftp
fct[nfct++].fct = factor;
}
template<typename T, bool bwd> NOINLINE void pass2 (size_t ido, size_t l1, const T * restrict cc,
T * restrict ch, const cmplx<T0> * restrict wa)
template<bool bwd, typename T> NOINLINE void pass2 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=2;
......@@ -402,7 +402,7 @@ template<typename T, bool bwd> NOINLINE void pass2 (size_t ido, size_t l1, const
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<typename T, bool bwd> NOINLINE void pass3 (size_t ido, size_t l1,
template<bool bwd, typename T> NOINLINE void pass3 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=3;
......@@ -429,7 +429,7 @@ template<typename T, bool bwd> NOINLINE void pass3 (size_t ido, size_t l1,
}
}
template<typename T, bool bwd> NOINLINE void pass4 (size_t ido, size_t l1,
template<bool bwd, typename T> NOINLINE void pass4 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=4;
......@@ -500,7 +500,7 @@ template<typename T, bool bwd> NOINLINE void pass4 (size_t ido, size_t l1,
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<typename T, bool bwd> NOINLINE void pass5 (size_t ido, size_t l1,
template<bool bwd, typename T> NOINLINE void pass5 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=5;
......@@ -560,7 +560,7 @@ template<typename T, bool bwd> NOINLINE void pass5 (size_t ido, size_t l1,
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<typename T, bool bwd> NOINLINE void pass7(size_t ido, size_t l1,
template<bool bwd, typename T> NOINLINE void pass7(size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=7;
......@@ -626,7 +626,7 @@ template<typename T, bool bwd> NOINLINE void pass7(size_t ido, size_t l1,
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<typename T, bool bwd> NOINLINE void pass11 (size_t ido, size_t l1,
template<bool bwd, typename T> NOINLINE void pass11 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
const size_t cdim=11;
......@@ -678,7 +678,7 @@ template<typename T, bool bwd> NOINLINE void pass11 (size_t ido, size_t l1,
#define CX2(a,b) cc[(a)+idl1*(b)]
#define CH2(a,b) ch[(a)+idl1*(b)]
template<typename T, bool bwd> NOINLINE void passg (size_t ido, size_t ip,
template<bool bwd, typename T> NOINLINE void passg (size_t ido, size_t ip,
size_t l1, T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa,
const cmplx<T0> * restrict csarr)
{
......@@ -779,8 +779,7 @@ template<typename T, bool bwd> NOINLINE void passg (size_t ido, size_t ip,
#undef CX2
#undef CX
template<typename T> NOINLINE void pass_all(T c[], T0 fact,
const int sign)
template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
{
if (length==1) return;
size_t l1=1;
......@@ -793,27 +792,20 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact,
size_t l2=ip*l1;
size_t ido = length/l2;
if (ip==4)
sign>0 ? pass4<T, true> (ido, l1, p1, p2, fct[k1].tw)
: pass4<T, false> (ido, l1, p1, p2, fct[k1].tw);
pass4<bwd> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==2)
sign>0 ? pass2<T, true>(ido, l1, p1, p2, fct[k1].tw)
: pass2<T, false>(ido, l1, p1, p2, fct[k1].tw);
pass2<bwd>(ido, l1, p1, p2, fct[k1].tw);
else if(ip==3)
sign>0 ? pass3<T, true> (ido, l1, p1, p2, fct[k1].tw)
: pass3<T, false> (ido, l1, p1, p2, fct[k1].tw);
pass3<bwd> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==5)
sign>0 ? pass5<T, true> (ido, l1, p1, p2, fct[k1].tw)
: pass5<T, false> (ido, l1, p1, p2, fct[k1].tw);
pass5<bwd> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==7)
sign>0 ? pass7<T, true> (ido, l1, p1, p2, fct[k1].tw)
: pass7<T, false> (ido, l1, p1, p2, fct[k1].tw);
pass7<bwd> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==11)
sign>0 ? pass11<T, true> (ido, l1, p1, p2, fct[k1].tw)
: pass11<T, false> (ido, l1, p1, p2, fct[k1].tw);
pass11<bwd> (ido, l1, p1, p2, fct[k1].tw);
else
{
sign>0 ? passg<T, true>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws)
: passg<T, false>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws);
passg<bwd>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws);
swap(p1,p2);
}
swap(p1,p2);
......@@ -840,10 +832,10 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact,
public:
template<typename T> NOINLINE void forward(T c[], T0 fct)
{ pass_all(c, fct, -1); }
{ pass_all<false>(c, fct); }
template<typename T> NOINLINE void backward(T c[], T0 fct)
{ pass_all(c, fct, 1); }
{ pass_all<true>(c, fct); }
private:
NOINLINE void factorize()
......@@ -1054,7 +1046,8 @@ template<typename T0> class pocketfft_c
size_t length() const { return len; }
};
struct diminfo { size_t n, s; };
struct diminfo
{ size_t n, s; };
class multiarr
{
private:
......@@ -1130,9 +1123,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
if (shape[axes[iax]] > tmpsize) tmpsize = shape[axes[iax]];
}
arr<cmplx<T>> tdata(tmpsize);
//cout << "tmpsize = " << tmpsize << endl;
multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
//cout << ndim << " " << a_in.ndim() << " " << a_in.size(0) << endl;
unique_ptr<pocketfft_c<T>> plan;
for (int iax=0; iax<nax; ++iax)
......@@ -1140,10 +1131,8 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
int axis = axes[iax];
multi_iter it_in(a_in, axis), it_out(a_out, axis);
bool inplace = (data_in==data_out) && (it_in.stride()==1);
//cout << "inplace = " << inplace << endl;
if ((!plan) || (it_in.length()!=plan->length()))
plan.reset(new pocketfft_c<T>(it_in.length()));
//cout << "plan length = " << plan->length() << endl;
while (!it_in.done())
{
if (inplace)
......@@ -1182,9 +1171,6 @@ namespace py = pybind11;
namespace {
using a_cd = py::array_t<complex<double>>;
using a_cs = py::array_t<complex<float>>;
py::array execute(const py::array &in, bool fwd)
{
py::array res;
......@@ -1192,9 +1178,9 @@ py::array execute(const py::array &in, bool fwd)
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
if (in.dtype().is(py::dtype("complex128")))
res = a_cd(dims);
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(py::dtype("complex64")))
res = a_cs(dims);
res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i)
{
......@@ -1222,8 +1208,6 @@ py::array ifftn(const py::array &in)
PYBIND11_MODULE(pypocketfft, m)
{
using namespace pybind11::literals;
m.def("fftn",&fftn);
m.def("ifftn",&ifftn);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment