Commit 06d75f2c authored by Martin Reinecke
Browse files

cosmetics

parent df6b820b
...@@ -358,8 +358,8 @@ template<typename T0> class cfftp ...@@ -358,8 +358,8 @@ template<typename T0> class cfftp
fct[nfct++].fct = factor; fct[nfct++].fct = factor;
} }
template<typename T, bool bwd> NOINLINE void pass2 (size_t ido, size_t l1, const T * restrict cc, template<bool bwd, typename T> NOINLINE void pass2 (size_t ido, size_t l1,
T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
constexpr size_t cdim=2; constexpr size_t cdim=2;
...@@ -402,7 +402,7 @@ template<typename T, bool bwd> NOINLINE void pass2 (size_t ido, size_t l1, const ...@@ -402,7 +402,7 @@ template<typename T, bool bwd> NOINLINE void pass2 (size_t ido, size_t l1, const
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \ CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \ CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
} }
template<typename T, bool bwd> NOINLINE void pass3 (size_t ido, size_t l1, template<bool bwd, typename T> NOINLINE void pass3 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
constexpr size_t cdim=3; constexpr size_t cdim=3;
...@@ -429,7 +429,7 @@ template<typename T, bool bwd> NOINLINE void pass3 (size_t ido, size_t l1, ...@@ -429,7 +429,7 @@ template<typename T, bool bwd> NOINLINE void pass3 (size_t ido, size_t l1,
} }
} }
template<typename T, bool bwd> NOINLINE void pass4 (size_t ido, size_t l1, template<bool bwd, typename T> NOINLINE void pass4 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
constexpr size_t cdim=4; constexpr size_t cdim=4;
...@@ -500,7 +500,7 @@ template<typename T, bool bwd> NOINLINE void pass4 (size_t ido, size_t l1, ...@@ -500,7 +500,7 @@ template<typename T, bool bwd> NOINLINE void pass4 (size_t ido, size_t l1,
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \ CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \ CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
} }
template<typename T, bool bwd> NOINLINE void pass5 (size_t ido, size_t l1, template<bool bwd, typename T> NOINLINE void pass5 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
constexpr size_t cdim=5; constexpr size_t cdim=5;
...@@ -560,7 +560,7 @@ template<typename T, bool bwd> NOINLINE void pass5 (size_t ido, size_t l1, ...@@ -560,7 +560,7 @@ template<typename T, bool bwd> NOINLINE void pass5 (size_t ido, size_t l1,
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \ CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
} }
template<typename T, bool bwd> NOINLINE void pass7(size_t ido, size_t l1, template<bool bwd, typename T> NOINLINE void pass7(size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
constexpr size_t cdim=7; constexpr size_t cdim=7;
...@@ -626,7 +626,7 @@ template<typename T, bool bwd> NOINLINE void pass7(size_t ido, size_t l1, ...@@ -626,7 +626,7 @@ template<typename T, bool bwd> NOINLINE void pass7(size_t ido, size_t l1,
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \ CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
} }
template<typename T, bool bwd> NOINLINE void pass11 (size_t ido, size_t l1, template<bool bwd, typename T> NOINLINE void pass11 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{ {
const size_t cdim=11; const size_t cdim=11;
...@@ -678,7 +678,7 @@ template<typename T, bool bwd> NOINLINE void pass11 (size_t ido, size_t l1, ...@@ -678,7 +678,7 @@ template<typename T, bool bwd> NOINLINE void pass11 (size_t ido, size_t l1,
#define CX2(a,b) cc[(a)+idl1*(b)] #define CX2(a,b) cc[(a)+idl1*(b)]
#define CH2(a,b) ch[(a)+idl1*(b)] #define CH2(a,b) ch[(a)+idl1*(b)]
template<typename T, bool bwd> NOINLINE void passg (size_t ido, size_t ip, template<bool bwd, typename T> NOINLINE void passg (size_t ido, size_t ip,
size_t l1, T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa, size_t l1, T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa,
const cmplx<T0> * restrict csarr) const cmplx<T0> * restrict csarr)
{ {
...@@ -779,8 +779,7 @@ template<typename T, bool bwd> NOINLINE void passg (size_t ido, size_t ip, ...@@ -779,8 +779,7 @@ template<typename T, bool bwd> NOINLINE void passg (size_t ido, size_t ip,
#undef CX2 #undef CX2
#undef CX #undef CX
template<typename T> NOINLINE void pass_all(T c[], T0 fact, template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
const int sign)
{ {
if (length==1) return; if (length==1) return;
size_t l1=1; size_t l1=1;
...@@ -793,27 +792,20 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact, ...@@ -793,27 +792,20 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact,
size_t l2=ip*l1; size_t l2=ip*l1;
size_t ido = length/l2; size_t ido = length/l2;
if (ip==4) if (ip==4)
sign>0 ? pass4<T, true> (ido, l1, p1, p2, fct[k1].tw) pass4<bwd> (ido, l1, p1, p2, fct[k1].tw);
: pass4<T, false> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==2) else if(ip==2)
sign>0 ? pass2<T, true>(ido, l1, p1, p2, fct[k1].tw) pass2<bwd>(ido, l1, p1, p2, fct[k1].tw);
: pass2<T, false>(ido, l1, p1, p2, fct[k1].tw);
else if(ip==3) else if(ip==3)
sign>0 ? pass3<T, true> (ido, l1, p1, p2, fct[k1].tw) pass3<bwd> (ido, l1, p1, p2, fct[k1].tw);
: pass3<T, false> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==5) else if(ip==5)
sign>0 ? pass5<T, true> (ido, l1, p1, p2, fct[k1].tw) pass5<bwd> (ido, l1, p1, p2, fct[k1].tw);
: pass5<T, false> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==7) else if(ip==7)
sign>0 ? pass7<T, true> (ido, l1, p1, p2, fct[k1].tw) pass7<bwd> (ido, l1, p1, p2, fct[k1].tw);
: pass7<T, false> (ido, l1, p1, p2, fct[k1].tw);
else if(ip==11) else if(ip==11)
sign>0 ? pass11<T, true> (ido, l1, p1, p2, fct[k1].tw) pass11<bwd> (ido, l1, p1, p2, fct[k1].tw);
: pass11<T, false> (ido, l1, p1, p2, fct[k1].tw);
else else
{ {
sign>0 ? passg<T, true>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws) passg<bwd>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws);
: passg<T, false>(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws);
swap(p1,p2); swap(p1,p2);
} }
swap(p1,p2); swap(p1,p2);
...@@ -840,10 +832,10 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact, ...@@ -840,10 +832,10 @@ template<typename T> NOINLINE void pass_all(T c[], T0 fact,
public: public:
template<typename T> NOINLINE void forward(T c[], T0 fct) template<typename T> NOINLINE void forward(T c[], T0 fct)
{ pass_all(c, fct, -1); } { pass_all<false>(c, fct); }
template<typename T> NOINLINE void backward(T c[], T0 fct) template<typename T> NOINLINE void backward(T c[], T0 fct)
{ pass_all(c, fct, 1); } { pass_all<true>(c, fct); }
private: private:
NOINLINE void factorize() NOINLINE void factorize()
...@@ -1054,7 +1046,8 @@ template<typename T0> class pocketfft_c ...@@ -1054,7 +1046,8 @@ template<typename T0> class pocketfft_c
size_t length() const { return len; } size_t length() const { return len; }
}; };
struct diminfo { size_t n, s; }; struct diminfo
{ size_t n, s; };
class multiarr class multiarr
{ {
private: private:
...@@ -1130,9 +1123,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, ...@@ -1130,9 +1123,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
if (shape[axes[iax]] > tmpsize) tmpsize = shape[axes[iax]]; if (shape[axes[iax]] > tmpsize) tmpsize = shape[axes[iax]];
} }
arr<cmplx<T>> tdata(tmpsize); arr<cmplx<T>> tdata(tmpsize);
//cout << "tmpsize = " << tmpsize << endl;
multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out); multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
//cout << ndim << " " << a_in.ndim() << " " << a_in.size(0) << endl;
unique_ptr<pocketfft_c<T>> plan; unique_ptr<pocketfft_c<T>> plan;
for (int iax=0; iax<nax; ++iax) for (int iax=0; iax<nax; ++iax)
...@@ -1140,10 +1131,8 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, ...@@ -1140,10 +1131,8 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
int axis = axes[iax]; int axis = axes[iax];
multi_iter it_in(a_in, axis), it_out(a_out, axis); multi_iter it_in(a_in, axis), it_out(a_out, axis);
bool inplace = (data_in==data_out) && (it_in.stride()==1); bool inplace = (data_in==data_out) && (it_in.stride()==1);
//cout << "inplace = " << inplace << endl;
if ((!plan) || (it_in.length()!=plan->length())) if ((!plan) || (it_in.length()!=plan->length()))
plan.reset(new pocketfft_c<T>(it_in.length())); plan.reset(new pocketfft_c<T>(it_in.length()));
//cout << "plan length = " << plan->length() << endl;
while (!it_in.done()) while (!it_in.done())
{ {
if (inplace) if (inplace)
...@@ -1182,9 +1171,6 @@ namespace py = pybind11; ...@@ -1182,9 +1171,6 @@ namespace py = pybind11;
namespace { namespace {
using a_cd = py::array_t<complex<double>>;
using a_cs = py::array_t<complex<float>>;
py::array execute(const py::array &in, bool fwd) py::array execute(const py::array &in, bool fwd)
{ {
py::array res; py::array res;
...@@ -1192,9 +1178,9 @@ py::array execute(const py::array &in, bool fwd) ...@@ -1192,9 +1178,9 @@ py::array execute(const py::array &in, bool fwd)
for (int i=0; i<in.ndim(); ++i) for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i); dims[i] = in.shape(i);
if (in.dtype().is(py::dtype("complex128"))) if (in.dtype().is(py::dtype("complex128")))
res = a_cd(dims); res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(py::dtype("complex64"))) else if (in.dtype().is(py::dtype("complex64")))
res = a_cs(dims); res = py::array_t<complex<float>>(dims);
else throw runtime_error("unsupported data type"); else throw runtime_error("unsupported data type");
for (int i=0; i<in.ndim(); ++i) for (int i=0; i<in.ndim(); ++i)
{ {
...@@ -1222,8 +1208,6 @@ py::array ifftn(const py::array &in) ...@@ -1222,8 +1208,6 @@ py::array ifftn(const py::array &in)
PYBIND11_MODULE(pypocketfft, m) PYBIND11_MODULE(pypocketfft, m)
{ {
using namespace pybind11::literals;
m.def("fftn",&fftn); m.def("fftn",&fftn);
m.def("ifftn",&ifftn); m.def("ifftn",&ifftn);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment