Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
2970b68c
Commit
2970b68c
authored
May 14, 2019
by
Martin Reinecke
Browse files
OpenMP support, take 1
parent
f7e2dfc1
Changes
3
Hide whitespace changes
Inline
Side-by-side
pocketfft_hdronly.h
View file @
2970b68c
...
...
@@ -51,6 +51,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#if defined(_WIN32)
#include
<malloc.h>
#endif
#ifdef POCKETFFT_OPENMP
#include
<omp.h>
#endif
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
...
...
@@ -457,6 +461,18 @@ struct util // hack to avoid duplicate symbols
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
if
(
axis
>=
shape
.
size
())
throw
runtime_error
(
"bad axis number"
);
}
#ifdef POCKETFFT_OPENMP
// OpenMP build: report the size of the calling thread's current team, as
// given by omp_get_num_threads(). The OpenMP runtime returns 1 when this is
// called outside of a parallel region.
static int nthreads()
{
  const int team_size = omp_get_num_threads();
  return team_size;
}
// OpenMP build: report the calling thread's index within its current team,
// as given by omp_get_thread_num() (0 for the master thread).
static int thread_num()
{
  const int my_index = omp_get_thread_num();
  return my_index;
}
// OpenMP build: heuristic deciding whether a transform along `axis` should
// be executed inside a parallel region.
//
// The decision is based on prod(shape)/shape[axis]; assuming prod() yields
// the product of all extents (TODO confirm), this is the number of
// independent 1-D transforms along `axis`. Parallel execution is requested
// only when that quotient exceeds 20.
//
// NOTE(review): shape[axis] is used as a divisor without a zero check here;
// the public entry points visible in this file return early when
// prod(shape) == 0 before reaching the parallel code.
//
// FIXME, needs improvement
static bool run_parallel(const shape_t &shape, size_t axis)
{
  const auto work_items = prod(shape) / shape[axis];
  return work_items > 20;
}
#else
// Serial fallback (POCKETFFT_OPENMP not defined): the thread count is
// always reported as 1.
static int nthreads()
{
  return 1;
}
// Serial fallback (POCKETFFT_OPENMP not defined): the calling thread's
// index is always reported as 0.
static int thread_num()
{
  return 0;
}
// Serial fallback (POCKETFFT_OPENMP not defined): parallel execution is
// never requested, regardless of the array shape or axis. Both arguments
// are intentionally unnamed because they are ignored.
static bool run_parallel(const shape_t & /*shape*/, size_t /*axis*/)
{
  return false;
}
#endif
};
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
...
...
@@ -2091,11 +2107,34 @@ template<size_t N, typename Ti, typename To> class multi_iter
}
public:
multi_iter
(
const
ndarr
<
Ti
>
&
iarr_
,
ndarr
<
To
>
&
oarr_
,
size_t
idim_
)
multi_iter
(
const
ndarr
<
Ti
>
&
iarr_
,
ndarr
<
To
>
&
oarr_
,
size_t
idim_
,
size_t
nshares
=
1
,
size_t
myshare
=
0
)
:
pos
(
iarr_
.
ndim
(),
0
),
iarr
(
iarr_
),
oarr
(
oarr_
),
p_ii
(
0
),
str_i
(
iarr
.
stride
(
idim_
)),
p_oi
(
0
),
str_o
(
oarr
.
stride
(
idim_
)),
idim
(
idim_
),
rem
(
iarr
.
size
()
/
iarr
.
shape
(
idim
))
{}
{
if
(
nshares
==
1
)
return
;
if
(
nshares
==
0
)
throw
runtime_error
(
"can't run with zero threads"
);
if
(
myshare
>=
nshares
)
throw
runtime_error
(
"impossible share requested"
);
size_t
nbase
=
rem
/
nshares
;
size_t
additional
=
rem
%
nshares
;
size_t
lo
=
myshare
*
nbase
+
((
myshare
<
additional
)
?
myshare
:
additional
);
size_t
hi
=
lo
+
nbase
+
(
myshare
<
additional
);
size_t
todo
=
hi
-
lo
;
size_t
chunk
=
rem
;
for
(
size_t
i
=
0
;
i
<
pos
.
size
();
++
i
)
{
if
(
i
==
idim
)
continue
;
chunk
/=
iarr
.
shape
(
i
);
size_t
n_advance
=
lo
/
chunk
;
pos
[
i
]
+=
n_advance
;
p_ii
+=
n_advance
*
iarr
.
stride
(
i
);
p_oi
+=
n_advance
*
oarr
.
stride
(
i
);
lo
-=
n_advance
*
chunk
;
}
rem
=
todo
;
}
void
advance
(
size_t
n
)
{
if
(
rem
<
n
)
throw
runtime_error
(
"underrun"
);
...
...
@@ -2167,18 +2206,24 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
template
<
typename
T
>
NOINLINE
void
general_c
(
const
ndarr
<
cmplx
<
T
>>
&
in
,
ndarr
<
cmplx
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
)
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
axes
,
sizeof
(
cmplx
<
T
>
));
unique_ptr
<
pocketfft_c
<
T
>>
plan
;
for
(
size_t
iax
=
0
;
iax
<
axes
.
size
();
++
iax
)
{
constexpr
int
vlen
=
VTYPE
<
T
>::
vlen
;
multi_iter
<
vlen
,
cmplx
<
T
>
,
cmplx
<
T
>>
it
(
iax
==
0
?
in
:
out
,
out
,
axes
[
iax
]);
size_t
len
=
it
.
length_in
();
size_t
len
=
in
.
shape
(
axes
[
iax
]);
if
((
!
plan
)
||
(
len
!=
plan
->
length
()))
plan
.
reset
(
new
pocketfft_c
<
T
>
(
len
));
#ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axes[iax])) num_threads(nthreads)
#endif
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
len
,
sizeof
(
cmplx
<
T
>
));
multi_iter
<
vlen
,
cmplx
<
T
>
,
cmplx
<
T
>>
it
(
iax
==
0
?
in
:
out
,
out
,
axes
[
iax
],
util
::
nthreads
(),
util
::
thread_num
());
#if defined(HAVE_VECSUPPORT)
if
(
vlen
>
1
)
while
(
it
.
remaining
()
>=
vlen
)
...
...
@@ -2218,23 +2263,31 @@ template<typename T> NOINLINE void general_c(
it
.
out
(
i
)
=
tdata
[
i
];
}
}
}
// end of parallel region
fct
=
T
(
1
);
// factor has been applied, use 1 for remaining axes
}
}
template
<
typename
T
>
NOINLINE
void
general_hartley
(
const
ndarr
<
T
>
&
in
,
ndarr
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
)
const
ndarr
<
T
>
&
in
,
ndarr
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
axes
,
sizeof
(
T
));
unique_ptr
<
pocketfft_r
<
T
>>
plan
;
for
(
size_t
iax
=
0
;
iax
<
axes
.
size
();
++
iax
)
{
constexpr
int
vlen
=
VTYPE
<
T
>::
vlen
;
multi_iter
<
vlen
,
T
,
T
>
it
(
iax
==
0
?
in
:
out
,
out
,
axes
[
iax
]);
size_t
len
=
it
.
length_in
();
size_t
len
=
in
.
shape
(
axes
[
iax
]);
if
((
!
plan
)
||
(
len
!=
plan
->
length
()))
plan
.
reset
(
new
pocketfft_r
<
T
>
(
len
));
#ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axes[iax])) num_threads(nthreads)
#endif
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
len
,
sizeof
(
T
));
multi_iter
<
vlen
,
T
,
T
>
it
(
iax
==
0
?
in
:
out
,
out
,
axes
[
iax
],
util
::
nthreads
(),
util
::
thread_num
());
#if defined(HAVE_VECSUPPORT)
if
(
vlen
>
1
)
while
(
it
.
remaining
()
>=
vlen
)
...
...
@@ -2275,19 +2328,25 @@ template<typename T> NOINLINE void general_hartley(
if
(
i
<
len
)
it
.
out
(
i1
)
=
tdata
[
i
];
}
}
// end of parallel region
fct
=
T
(
1
);
// factor has been applied, use 1 for remaining axes
}
}
template
<
typename
T
>
NOINLINE
void
general_r2c
(
const
ndarr
<
T
>
&
in
,
ndarr
<
cmplx
<
T
>>
&
out
,
size_t
axis
,
T
fct
)
const
ndarr
<
T
>
&
in
,
ndarr
<
cmplx
<
T
>>
&
out
,
size_t
axis
,
T
fct
,
size_t
nthreads
=
1
)
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
in
.
shape
(
axis
),
sizeof
(
T
));
pocketfft_r
<
T
>
plan
(
in
.
shape
(
axis
));
constexpr
int
vlen
=
VTYPE
<
T
>::
vlen
;
multi_iter
<
vlen
,
T
,
cmplx
<
T
>>
it
(
in
,
out
,
axis
);
size_t
len
=
in
.
shape
(
axis
);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads)
#endif
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
len
,
sizeof
(
T
));
multi_iter
<
vlen
,
T
,
cmplx
<
T
>>
it
(
in
,
out
,
axis
,
util
::
nthreads
(),
util
::
thread_num
());
#if defined(HAVE_VECSUPPORT)
if
(
vlen
>
1
)
while
(
it
.
remaining
()
>=
vlen
)
...
...
@@ -2324,16 +2383,22 @@ template<typename T> NOINLINE void general_r2c(
if
(
i
<
len
)
it
.
out
(
ii
).
Set
(
tdata
[
i
]);
}
}
// end of parallel region
}
template
<
typename
T
>
NOINLINE
void
general_c2r
(
const
ndarr
<
cmplx
<
T
>>
&
in
,
ndarr
<
T
>
&
out
,
size_t
axis
,
T
fct
)
const
ndarr
<
cmplx
<
T
>>
&
in
,
ndarr
<
T
>
&
out
,
size_t
axis
,
T
fct
,
size_t
nthreads
=
1
)
{
auto
storage
=
alloc_tmp
<
T
>
(
out
.
shape
(),
out
.
shape
(
axis
),
sizeof
(
T
));
pocketfft_r
<
T
>
plan
(
out
.
shape
(
axis
));
constexpr
int
vlen
=
VTYPE
<
T
>::
vlen
;
multi_iter
<
vlen
,
cmplx
<
T
>
,
T
>
it
(
in
,
out
,
axis
);
size_t
len
=
out
.
shape
(
axis
);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads)
#endif
{
auto
storage
=
alloc_tmp
<
T
>
(
out
.
shape
(),
len
,
sizeof
(
T
));
multi_iter
<
vlen
,
cmplx
<
T
>
,
T
>
it
(
in
,
out
,
axis
,
util
::
nthreads
(),
util
::
thread_num
());
#if defined(HAVE_VECSUPPORT)
if
(
vlen
>
1
)
while
(
it
.
remaining
()
>=
vlen
)
...
...
@@ -2374,17 +2439,23 @@ template<typename T> NOINLINE void general_c2r(
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
it
.
out
(
i
)
=
tdata
[
i
];
}
}
// end of parallel region
}
template
<
typename
T
>
NOINLINE
void
general_r
(
const
ndarr
<
T
>
&
in
,
ndarr
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
)
const
ndarr
<
T
>
&
in
,
ndarr
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
in
.
shape
(
axis
),
sizeof
(
T
));
constexpr
int
vlen
=
VTYPE
<
T
>::
vlen
;
multi_iter
<
vlen
,
T
,
T
>
it
(
in
,
out
,
axis
);
size_t
len
=
it
.
length_in
();
size_t
len
=
in
.
shape
(
axis
);
pocketfft_r
<
T
>
plan
(
len
);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel if(util::run_parallel(in.shape(), axis)) num_threads(nthreads)
#endif
{
auto
storage
=
alloc_tmp
<
T
>
(
in
.
shape
(),
len
,
sizeof
(
T
));
multi_iter
<
vlen
,
T
,
T
>
it
(
in
,
out
,
axis
,
util
::
nthreads
(),
util
::
thread_num
());
#if defined(HAVE_VECSUPPORT)
if
(
vlen
>
1
)
while
(
it
.
remaining
()
>=
vlen
)
...
...
@@ -2424,6 +2495,7 @@ template<typename T> NOINLINE void general_r(
it
.
out
(
i
)
=
tdata
[
i
];
}
}
}
// end of parallel region
}
#undef HAVE_VECSUPPORT
...
...
@@ -2432,19 +2504,20 @@ template<typename T> NOINLINE void general_r(
template
<
typename
T
>
void
c2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
std
::
complex
<
T
>
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
const
std
::
complex
<
T
>
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape_in
)
==
0
)
return
;
...
...
@@ -2453,29 +2526,30 @@ template<typename T> void r2c(const shape_t &shape_in,
shape_t
shape_out
(
shape_in
);
shape_out
[
axis
]
=
shape_in
[
axis
]
/
2
+
1
;
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape_out
,
stride_out
);
general_r2c
(
ain
,
aout
,
axis
,
fct
);
general_r2c
(
ain
,
aout
,
axis
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape_in
)
==
0
)
return
;
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axes
);
r2c
(
shape_in
,
stride_in
,
stride_out
,
axes
.
back
(),
data_in
,
data_out
,
fct
);
r2c
(
shape_in
,
stride_in
,
stride_out
,
axes
.
back
(),
data_in
,
data_out
,
fct
,
nthreads
);
if
(
axes
.
size
()
==
1
)
return
;
shape_t
shape_out
(
shape_in
);
shape_out
[
axes
.
back
()]
=
shape_in
[
axes
.
back
()]
/
2
+
1
;
auto
newaxes
=
shape_t
{
axes
.
begin
(),
--
axes
.
end
()};
c2c
(
shape_out
,
stride_out
,
stride_out
,
newaxes
,
true
,
data_out
,
data_out
,
T
(
1
));
T
(
1
)
,
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
)
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape_out
)
==
0
)
return
;
...
...
@@ -2484,20 +2558,18 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in
[
axis
]
=
shape_out
[
axis
]
/
2
+
1
;
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape_in
,
stride_in
);
ndarr
<
T
>
aout
(
data_out
,
shape_out
,
stride_out
);
general_c2r
(
ain
,
aout
,
axis
,
fct
);
general_c2r
(
ain
,
aout
,
axis
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
)
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape_out
)
==
0
)
return
;
if
(
axes
.
size
()
==
1
)
{
c2r
(
shape_out
,
stride_in
,
stride_out
,
axes
[
0
],
data_in
,
data_out
,
fct
);
return
;
}
return
c2r
(
shape_out
,
stride_in
,
stride_out
,
axes
[
0
],
data_in
,
data_out
,
fct
,
nthreads
);
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axes
);
auto
shape_in
=
shape_out
;
shape_in
[
axes
.
back
()]
=
shape_out
[
axes
.
back
()]
/
2
+
1
;
...
...
@@ -2509,31 +2581,31 @@ template<typename T> void c2r(const shape_t &shape_out,
arr
<
complex
<
T
>>
tmp
(
nval
);
auto
newaxes
=
shape_t
({
axes
.
begin
(),
--
axes
.
end
()});
c2c
(
shape_in
,
stride_in
,
stride_inter
,
newaxes
,
false
,
data_in
,
tmp
.
data
(),
T
(
1
));
T
(
1
)
,
nthreads
);
c2r
(
shape_out
,
stride_inter
,
stride_out
,
axes
.
back
(),
tmp
.
data
(),
data_out
,
fct
);
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2r_fftpack
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
bool
forward
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
)
bool
forward
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_r
(
ain
,
aout
,
axis
,
forward
,
fct
);
general_r
(
ain
,
aout
,
axis
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2r_hartley
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
)
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
{
using
namespace
detail
;
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_hartley
(
ain
,
aout
,
axes
,
fct
);
general_hartley
(
ain
,
aout
,
axes
,
fct
,
nthreads
);
}
}
// namespace pocketfft
...
...
pypocketfft.cc
View file @
2970b68c
...
...
@@ -14,8 +14,6 @@
#include
<pybind11/numpy.h>
#include
<pybind11/stl.h>
#pragma GCC visibility push(hidden)
#include
"pocketfft_hdronly.h"
//
...
...
@@ -78,31 +76,31 @@ shape_t makeaxes(const py::array &in, py::object axes)
throw runtime_error("unsupported data type");
template
<
typename
T
>
py
::
array
xfftn_internal
(
const
py
::
array
&
in
,
const
shape_t
&
axes
,
double
fct
,
bool
inplace
,
bool
fwd
)
const
shape_t
&
axes
,
double
fct
,
bool
inplace
,
bool
fwd
,
size_t
nthreads
)
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
complex
<
T
>>
(
dims
);
c2c
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
fwd
,
reinterpret_cast
<
const
complex
<
T
>
*>
(
in
.
data
()),
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
));
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
)
,
nthreads
);
return
res
;
}
py
::
array
xfftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
,
bool
inplace
,
bool
fwd
)
bool
fwd
,
size_t
nthreads
)
{
DISPATCH
(
a
,
c128
,
c64
,
c256
,
xfftn_internal
,
(
a
,
makeaxes
(
a
,
axes
),
fct
,
inplace
,
fwd
))
inplace
,
fwd
,
nthreads
))
}
// Forward transform entry point without a thread-count argument.
// Delegates to xfftn() with the direction flag (`fwd`) set to true; all
// other arguments are passed through unchanged.
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace)
{
  const bool forward = true;
  return xfftn(a, axes, fct, inplace, forward);
}
// Forward transform entry point.
// Delegates to xfftn() with the direction flag (`fwd`) set to true; the
// array, axes, scaling factor, in-place flag and requested thread count are
// passed through unchanged.
py::array fftn(const py::array &a, py::object axes, double fct, bool inplace,
               size_t nthreads)
{
  const bool forward = true;
  return xfftn(a, axes, fct, inplace, forward, nthreads);
}
// Backward transform entry point without a thread-count argument.
// Delegates to xfftn() with the direction flag (`fwd`) set to false; all
// other arguments are passed through unchanged.
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace)
{
  const bool forward = false;
  return xfftn(a, axes, fct, inplace, forward);
}
// Backward transform entry point.
// Delegates to xfftn() with the direction flag (`fwd`) set to false; the
// array, axes, scaling factor, in-place flag and requested thread count are
// passed through unchanged.
py::array ifftn(const py::array &a, py::object axes, double fct, bool inplace,
                size_t nthreads)
{
  const bool forward = false;
  return xfftn(a, axes, fct, inplace, forward, nthreads);
}
template
<
typename
T
>
py
::
array
rfftn_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
T
fct
)
py
::
object
axes_
,
T
fct
,
size_t
nthreads
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
...
...
@@ -110,38 +108,38 @@ template<typename T> py::array rfftn_internal(const py::array &in,
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
r2c
(
dims_in
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
reinterpret_cast
<
const
T
*>
(
in
.
data
()),
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
));
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
)
,
nthreads
);
return
res
;
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
size_t
nthreads
)
{
DISPATCH
(
in
,
f64
,
f32
,
f128
,
rfftn_internal
,
(
in
,
axes_
,
fct
))
DISPATCH
(
in
,
f64
,
f32
,
f128
,
rfftn_internal
,
(
in
,
axes_
,
fct
,
nthreads
))
}
template
<
typename
T
>
py
::
array
xrfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
,
bool
fwd
)
size_t
axis
,
double
fct
,
bool
inplace
,
bool
fwd
,
size_t
nthreads
)
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
r2r_fftpack
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axis
,
fwd
,
reinterpret_cast
<
const
T
*>
(
in
.
data
()),
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
));
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
)
,
nthreads
);
return
res
;
}
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
,
size_t
nthreads
)
{
DISPATCH
(
in
,
f64
,
f32
,
f128
,
xrfft_scipy
,
(
in
,
axis
,
fct
,
inplace
,
true
))
DISPATCH
(
in
,
f64
,
f32
,
f128
,
xrfft_scipy
,
(
in
,
axis
,
fct
,
inplace
,
true
,
nthreads
))
}
py
::
array
irfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
bool
inplace
,
size_t
nthreads
)
{
DISPATCH
(
in
,
f64
,
f32
,
f128
,
xrfft_scipy
,
(
in
,
axis
,
fct
,
inplace
,
false
))
DISPATCH
(
in
,
f64
,
f32
,
f128
,
xrfft_scipy
,
(
in
,
axis
,
fct
,
inplace
,
false
,
nthreads
))
}
template
<
typename
T
>
py
::
array
irfftn_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
py
::
object
axes_
,
size_t
lastsize
,
T
fct
,
size_t
nthreads
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
size_t
axis
=
axes
.
back
();
...
...
@@ -153,31 +151,31 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
c2r
(
dims_out
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
reinterpret_cast
<
const
complex
<
T
>
*>
(
in
.
data
()),
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
));
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
)
,
nthreads
);
return
res
;
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
double
fct
)
double
fct
,
size_t
nthreads
)
{
DISPATCH
(
in
,
c128
,
c64
,
c256
,
irfftn_internal
,
(
in
,
axes_
,
lastsize
,
fct
))
DISPATCH
(
in
,
c128
,
c64
,
c256
,
irfftn_internal
,
(
in
,
axes_
,
lastsize
,
fct
,
nthreads
))
}
template
<
typename
T
>
py
::
array
hartley_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
bool
inplace
)
py
::
object
axes_
,
double
fct
,
bool
inplace
,
size_t
nthreads
)
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
r2r_hartley
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
makeaxes
(
in
,
axes_
),
reinterpret_cast
<
const
T
*>
(
in
.
data
()),
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
));
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
)
,
nthreads
);
return
res
;
}
py
::
array
hartley
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
bool
inplace
)
bool
inplace
,
size_t
nthreads
)
{
DISPATCH
(
in
,
f64
,
f32
,
f128
,
hartley_internal
,
(
in
,
axes_
,
fct
,
inplace
))
DISPATCH
(
in
,
f64
,
f32
,
f128
,
hartley_internal
,
(
in
,
axes_
,
fct
,
inplace
,
nthreads
))
}
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
...
...
@@ -230,8 +228,8 @@ py::array mycomplex2hartley(const py::array &in,
}
py
::
array
hartley2
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
bool
inplace
)
{
return
mycomplex2hartley
(
in
,
rfftn
(
in
,
axes_
,
fct
),
axes_
,
inplace
);
}
bool
inplace
,
size_t
nthreads
)
{
return
mycomplex2hartley
(
in
,
rfftn
(
in
,
axes_
,
fct
,
nthreads
),
axes_
,
inplace
);
}
const
char
*
pypocketfft_DS
=
R"DELIM(Fast Fourier and Hartley transforms.
...
...
@@ -399,28 +397,26 @@ np.ndarray (same shape and data type as a)
}
// unnamed namespace
#pragma GCC visibility pop
PYBIND11_MODULE
(
pypocketfft
,
m
)
{
using
namespace
pybind11
::
literals
;
m
.
doc
()
=
pypocketfft_DS
;
m
.
def
(
"fftn"
,
&
fftn
,
fftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"ifftn"
,
&
ifftn
,
ifftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
m
.
def
(
"rfftn"
,
&
rfftn
,
rfftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"rfftn"
,
&
rfftn
,
rfftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
,
"nthreads"
_a
=
1
);
m
.
def
(
"rfft_scipy"
,
&
rfft_scipy
,
rfft_scipy_DS
,
"a"
_a
,
"axis"
_a
,
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"irfftn"
,
&
irfftn
,
irfftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"lastsize"
_a
=
0
,
"fct"
_a
=
1.
);
"fct"
_a
=
1.
,
"nthreads"
_a
=
1
);
m
.
def
(
"irfft_scipy"
,
&
irfft_scipy
,
irfft_scipy_DS
,
"a"
_a
,
"axis"
_a
,
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"hartley"
,
&
hartley
,
hartley_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"hartley2"
,
&
hartley2
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
,
"inplace"
_a
=
false
);
"inplace"
_a
=
false
,
"nthreads"
_a
=
1
);
m
.
def
(
"complex2hartley"
,
&
mycomplex2hartley
,
"in"
_a
,
"tmp"
_a
,
"axes"
_a
,
"inplace"
_a
=
false
);
}
setup.py
View file @
2970b68c
...
...
@@ -63,10 +63,11 @@ if sys.platform == 'darwin':
builder
=
setuptools
.
command
.
build_ext
.
build_ext
(
Distribution
())
base_library_link_args
.
append
(
'-dynamiclib'
)
else
:
extra_compile_args
+=
[
'-march=native'
,
'-O3'
,
'-Wfatal-errors'
,
'-Wno-ignored-attributes'
]
extra_compile_args
+=
[
'-march=native'
,
'-O3'
,
'-Wfatal-errors'
,
'-Wno-ignored-attributes'
,
'-DPOCKETFFT_OPENMP'
,
'-fopenmp'
]
python_module_link_args
+=
[
'-march=native'
]
extra_cc_compile_args
.
append
(
'--std=c++11'
)
python_module_link_args
.
append
(
"-Wl,-rpath,$ORIGIN"
)
python_module_link_args
.
append
(
'-fopenmp'
)
extra_cc_compile_args
=
extra_compile_args
+
extra_cc_compile_args
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment