Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
fcfd81ff
Commit
fcfd81ff
authored
May 10, 2019
by
Martin Reinecke
Browse files
Merge branch 'develop' into 'master'
Develop See merge request
!4
parents
cef9199e
8a3d6ff7
Changes
2
Hide whitespace changes
Inline
Side-by-side
pocketfft_hdronly.h
View file @
fcfd81ff
...
...
@@ -19,6 +19,7 @@
#include
<stdexcept>
#include
<memory>
#include
<vector>
#include
<complex>
#if defined(_WIN32)
#include
<malloc.h>
#endif
...
...
@@ -36,6 +37,8 @@ namespace pocketfft {
using
shape_t
=
std
::
vector
<
size_t
>
;
using
stride_t
=
std
::
vector
<
ptrdiff_t
>
;
constexpr
bool
FORWARD
=
true
,
BACKWARD
=
false
;
namespace
detail
{
...
...
@@ -156,13 +159,29 @@ template<typename T> void ROTM90(cmplx<T> &a)
template
<
typename
T
>
class
sincos_2pibyn
{
private:
template
<
typename
Ta
,
typename
Tb
,
bool
bigger
>
struct
TypeSelector
{};
template
<
typename
Ta
,
typename
Tb
>
struct
TypeSelector
<
Ta
,
Tb
,
true
>
{
using
type
=
Ta
;
};
template
<
typename
Ta
,
typename
Tb
>
struct
TypeSelector
<
Ta
,
Tb
,
false
>
{
using
type
=
Tb
;
};
using
Thigh
=
typename
TypeSelector
<
T
,
double
,
(
sizeof
(
T
)
>
sizeof
(
double
))
>::
type
;
arr
<
T
>
data
;
// adapted from https://stackoverflow.com/questions/42792939/
// CAUTION: this function only works for arguments in the range
// [-0.25; 0.25]!
void
my_sincosm1pi
(
double
a
,
double
*
restrict
res
)
void
my_sincosm1pi
(
Thigh
a
,
Thigh
*
restrict
res
)
{
if
(
sizeof
(
Thigh
)
>
sizeof
(
double
))
// don't have the code for long double
{
Thigh
pi
=
Thigh
(
3.141592653589793238462643383279502884197
L
);
res
[
1
]
=
sin
(
pi
*
a
);
auto
s
=
res
[
1
];
res
[
0
]
=
(
s
*
s
)
/
(
-
sqrt
((
1
-
s
)
*
(
1
+
s
))
-
1
);
return
;
}
double
s
=
a
*
a
;
/* Approximate cos(pi*x)-1 for x in [-0.25,0.25] */
double
r
=
-
1.0369917389758117e-4
;
...
...
@@ -194,25 +213,25 @@ template<typename T> class sincos_2pibyn
res
[
0
]
=
1.
;
res
[
1
]
=
0.
;
if
(
n
==
1
)
return
;
size_t
l1
=
size_t
(
sqrt
(
n
));
arr
<
double
>
tmp
(
2
*
l1
);
arr
<
Thigh
>
tmp
(
2
*
l1
);
for
(
size_t
i
=
1
;
i
<
l1
;
++
i
)
{
my_sincosm1pi
((
2.
*
i
)
/
den
,
&
tmp
[
2
*
i
]);
my_sincosm1pi
((
Thigh
(
2
)
*
i
)
/
den
,
&
tmp
[
2
*
i
]);
res
[
2
*
i
]
=
tmp
[
2
*
i
]
+
1.
;
res
[
2
*
i
+
1
]
=
tmp
[
2
*
i
+
1
];
}
size_t
start
=
l1
;
while
(
start
<
n
)
{
double
cs
[
2
];
my_sincosm1pi
((
2.
*
start
)
/
den
,
cs
);
Thigh
cs
[
2
];
my_sincosm1pi
((
Thigh
(
2
)
*
start
)
/
den
,
cs
);
res
[
2
*
start
]
=
cs
[
0
]
+
1.
;
res
[
2
*
start
+
1
]
=
cs
[
1
];
size_t
end
=
l1
;
if
(
start
+
end
>
n
)
end
=
n
-
start
;
for
(
size_t
i
=
1
;
i
<
end
;
++
i
)
{
double
csx
[
2
]
=
{
tmp
[
2
*
i
],
tmp
[
2
*
i
+
1
]};
Thigh
csx
[
2
]
=
{
tmp
[
2
*
i
],
tmp
[
2
*
i
+
1
]};
res
[
2
*
(
start
+
i
)]
=
((
cs
[
0
]
*
csx
[
0
]
-
cs
[
1
]
*
csx
[
1
]
+
cs
[
0
])
+
csx
[
0
])
+
1.
;
res
[
2
*
(
start
+
i
)
+
1
]
=
(
cs
[
0
]
*
csx
[
1
]
+
cs
[
1
]
*
csx
[
0
])
+
cs
[
1
]
+
csx
[
1
];
}
...
...
@@ -253,7 +272,7 @@ template<typename T> class sincos_2pibyn
void
fill_first_quadrant
(
size_t
n
,
T
*
restrict
res
)
{
const
double
hsqt2
=
0.707106781186547524400844362104849
;
const
expr
Thigh
hsqt2
=
Thigh
(
0.707106781186547524400844362104849
L
)
;
size_t
quart
=
n
>>
2
;
if
((
n
&
7
)
==
0
)
res
[
quart
]
=
res
[
quart
+
1
]
=
hsqt2
;
...
...
@@ -336,7 +355,7 @@ struct util // hack to avoid duplicate symbols
static
NOINLINE
double
cost_guess
(
size_t
n
)
{
const
double
lfp
=
1.1
;
// penalty for non-hardcoded larger factors
const
expr
double
lfp
=
1.1
;
// penalty for non-hardcoded larger factors
size_t
ni
=
n
;
double
result
=
0.
;
while
((
n
&
1
)
==
0
)
...
...
@@ -377,6 +396,41 @@ struct util // hack to avoid duplicate symbols
res
*=
sz
;
return
res
;
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
)
{
auto
ndim
=
shape
.
size
();
if
(
ndim
<
1
)
throw
runtime_error
(
"ndim must be >= 1"
);
if
((
stride_in
.
size
()
!=
ndim
)
||
(
stride_out
.
size
()
!=
ndim
))
throw
runtime_error
(
"stride dimension mismatch"
);
for
(
auto
shp
:
shape
)
if
(
shp
<
1
)
throw
runtime_error
(
"zero extent detected"
);
if
(
inplace
&&
(
stride_in
!=
stride_out
))
throw
runtime_error
(
"stride mismatch"
);
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
const
shape_t
&
axes
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
auto
ndim
=
shape
.
size
();
shape_t
tmp
(
ndim
,
0
);
for
(
auto
ax
:
axes
)
{
if
(
ax
>=
ndim
)
throw
runtime_error
(
"bad axis number"
);
if
(
++
tmp
[
ax
]
>
1
)
throw
runtime_error
(
"axis specified repeatedly"
);
}
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
size_t
axis
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
if
(
axis
>=
shape
.
size
())
throw
runtime_error
(
"bad axis number"
);
}
};
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
...
...
@@ -390,7 +444,6 @@ struct util // hack to avoid duplicate symbols
template
<
typename
T0
>
class
cfftp
{
private:
struct
fctdata
{
size_t
fct
;
...
...
@@ -452,7 +505,8 @@ template<bool bwd, typename T> void pass3 (size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
cmplx
<
T0
>
*
restrict
wa
)
{
constexpr
size_t
cdim
=
3
;
constexpr
T0
tw1r
=-
0.5
,
tw1i
=
(
bwd
?
1.
:
-
1.
)
*
0.86602540378443864676
;
constexpr
T0
tw1r
=-
0.5
,
tw1i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.8660254037844386467637231707529362
L
);
if
(
ido
==
1
)
for
(
size_t
k
=
0
;
k
<
l1
;
++
k
)
...
...
@@ -554,10 +608,10 @@ template<bool bwd, typename T> void pass5 (size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
cmplx
<
T0
>
*
restrict
wa
)
{
constexpr
size_t
cdim
=
5
;
constexpr
T0
tw1r
=
0.3090169943749474241
,
tw1i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.9510565162951535721
2
,
tw2r
=
-
0.8090169943749474241
,
tw2i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.5877852522924731291
7
;
constexpr
T0
tw1r
=
T0
(
0.3090169943749474241
022934171828191
L
)
,
tw1i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.9510565162951535721
164393333793821
L
)
,
tw2r
=
T0
(
-
0.8090169943749474241
022934171828191
L
)
,
tw2i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.5877852522924731291
687059546390728
L
)
;
if
(
ido
==
1
)
for
(
size_t
k
=
0
;
k
<
l1
;
++
k
)
...
...
@@ -618,12 +672,12 @@ template<bool bwd, typename T> void pass7(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
cmplx
<
T0
>
*
restrict
wa
)
{
constexpr
size_t
cdim
=
7
;
constexpr
T0
tw1r
=
0.623489801858733530525
,
tw1i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.7818314824680298087084
,
tw2r
=
-
0.22252093395631440428
9
,
tw2i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.9749279121818236070181
,
tw3r
=
-
0.9009688679024191262361
,
tw3i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.433883739117558120475
8
;
constexpr
T0
tw1r
=
T0
(
0.623489801858733530525
0048840042398
L
)
,
tw1i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.7818314824680298087084
445266740578
L
)
,
tw2r
=
T0
(
-
0.22252093395631440428
89025644967948
L
)
,
tw2i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.9749279121818236070181
316829939312
L
)
,
tw3r
=
T0
(
-
0.9009688679024191262361
023195074451
L
)
,
tw3i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.433883739117558120475
768332848359
L
)
;
if
(
ido
==
1
)
for
(
size_t
k
=
0
;
k
<
l1
;
++
k
)
...
...
@@ -688,17 +742,17 @@ template<bool bwd, typename T> void pass7(size_t ido, size_t l1,
template
<
bool
bwd
,
typename
T
>
void
pass11
(
size_t
ido
,
size_t
l1
,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
cmplx
<
T0
>
*
restrict
wa
)
{
const
size_t
cdim
=
11
;
const
T0
tw1r
=
0.8412535328311811688618
,
tw1i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.5406408174555975821076
,
tw2r
=
0.415415013001886425529
3
,
tw2i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.9096319953545183714117
,
tw3r
=
-
0.142314838273285140443
8
,
tw3i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.989821441880932732376
1
,
tw4r
=
-
0.6548607339452850640569
,
tw4i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.755749574354258283774
,
tw5r
=
-
0.959492973614497389890
4
,
tw5i
=
(
bwd
?
1
.
:
-
1
.
)
*
0.2817325568414296977114
;
const
expr
size_t
cdim
=
11
;
const
expr
T0
tw1r
=
T0
(
0.8412535328311811688618
116489193677
L
)
,
tw1i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.5406408174555975821076
359543186917
L
)
,
tw2r
=
T0
(
0.415415013001886425529
2741492296232
L
)
,
tw2i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.9096319953545183714117
153830790285
L
)
,
tw3r
=
T0
(
-
0.142314838273285140443
7926686163697
L
)
,
tw3i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.989821441880932732376
0920377767188
L
)
,
tw4r
=
T0
(
-
0.6548607339452850640569
250724662936
L
)
,
tw4i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.755749574354258283774
0358439723444
L
)
,
tw5r
=
T0
(
-
0.959492973614497389890
3680570663277
L
)
,
tw5i
=
(
bwd
?
1
:
-
1
)
*
T0
(
0.2817325568414296977114
179153466169
L
)
;
if
(
ido
==
1
)
for
(
size_t
k
=
0
;
k
<
l1
;
++
k
)
...
...
@@ -1035,7 +1089,7 @@ template<typename T> void radf3(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
3
;
constexpr
T0
taur
=-
0.5
,
taui
=
0.86602540378443864676
;
constexpr
T0
taur
=-
0.5
,
taui
=
T0
(
0.86602540378443864676
37231707529362
L
)
;
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1069,7 +1123,7 @@ template<typename T> void radf4(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
4
;
constexpr
T0
hsqt2
=
0.70710678118654752440
;
constexpr
T0
hsqt2
=
T0
(
0.70710678118654752440
0844362104849
L
)
;
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1110,8 +1164,10 @@ template<typename T> void radf5(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
5
;
constexpr
T0
tr11
=
0.3090169943749474241
,
ti11
=
0.95105651629515357212
,
tr12
=-
0.8090169943749474241
,
ti12
=
0.58778525229247312917
;
constexpr
T0
tr11
=
T0
(
0.3090169943749474241022934171828191
L
),
ti11
=
T0
(
0.9510565162951535721164393333793821
L
),
tr12
=
T0
(
-
0.8090169943749474241022934171828191
L
),
ti12
=
T0
(
0.5877852522924731291687059546390728
L
);
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1335,7 +1391,7 @@ template<typename T> void radb3(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
3
;
constexpr
T0
taur
=-
0.5
,
taui
=
0.86602540378443864676
;
constexpr
T0
taur
=-
0.5
,
taui
=
T0
(
0.86602540378443864676
37231707529362
L
)
;
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1370,7 +1426,7 @@ template<typename T> void radb4(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
4
;
constexpr
T0
sqrt2
=
1.41421356237309504880
;
constexpr
T0
sqrt2
=
T0
(
1.41421356237309504880
1688724209698
L
)
;
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1416,8 +1472,10 @@ template<typename T> void radb5(size_t ido, size_t l1,
const
T
*
restrict
cc
,
T
*
restrict
ch
,
const
T0
*
restrict
wa
)
{
constexpr
size_t
cdim
=
5
;
constexpr
T0
tr11
=
0.3090169943749474241
,
ti11
=
0.95105651629515357212
,
tr12
=-
0.8090169943749474241
,
ti12
=
0.58778525229247312917
;
constexpr
T0
tr11
=
T0
(
0.3090169943749474241022934171828191
L
),
ti11
=
T0
(
0.9510565162951535721164393333793821
L
),
tr12
=
T0
(
-
0.8090169943749474241022934171828191
L
),
ti12
=
T0
(
0.5877852522924731291687059546390728
L
);
for
(
size_t
k
=
0
;
k
<
l1
;
k
++
)
{
...
...
@@ -1825,7 +1883,7 @@ template<typename T0> class fftblue
}
/* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */
T0
xn2
=
1.
/
n2
;
T0
xn2
=
T0
(
1
)
/
n2
;
bkf
[
0
]
=
bk
[
0
]
*
xn2
;
for
(
size_t
m
=
1
;
m
<
n
;
++
m
)
bkf
[
m
]
=
bkf
[
n2
-
m
]
=
bk
[
m
]
*
xn2
;
...
...
@@ -1996,10 +2054,9 @@ template<size_t N, typename Ti, typename To> class multi_iter
for
(
int
i
=
pos
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
if
(
i
==
int
(
idim
))
continue
;
++
pos
[
i
];
p_ii
+=
iarr
.
stride
(
i
);
p_oi
+=
oarr
.
stride
(
i
);
if
(
pos
[
i
]
<
iarr
.
shape
(
i
))
if
(
++
pos
[
i
]
<
iarr
.
shape
(
i
))
return
;
pos
[
i
]
=
0
;
p_ii
-=
iarr
.
shape
(
i
)
*
iarr
.
stride
(
i
);
...
...
@@ -2039,7 +2096,12 @@ template<size_t N, typename Ti, typename To> class multi_iter
};
#if defined(HAVE_VECSUPPORT)
template
<
typename
T
>
struct
VTYPE
{};
template
<
typename
T
>
struct
VTYPE
{};
template
<
>
struct
VTYPE
<
long
double
>
{
using
type
=
long
double
__attribute__
((
vector_size
(
sizeof
(
long
double
))));
static
constexpr
int
vlen
=
1
;
};
template
<
>
struct
VTYPE
<
double
>
{
using
type
=
double
__attribute__
((
vector_size
(
VBYTELEN
)));
...
...
@@ -2051,10 +2113,7 @@ template<> struct VTYPE<float>
static
constexpr
int
vlen
=
VBYTELEN
/
sizeof
(
float
);
};
#else
template
<
typename
T
>
struct
VTYPE
{};
template
<
>
struct
VTYPE
<
double
>
{
static
constexpr
int
vlen
=
1
;
};
template
<
>
struct
VTYPE
<
float
>
template
<
typename
T
>
struct
VTYPE
{
static
constexpr
int
vlen
=
1
;
};
#endif
...
...
@@ -2345,52 +2404,99 @@ template<typename T> NOINLINE void general_r(
}
// namespace detail
template
<
typename
T
>
void
c2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
template
<
typename
T
>
void
c2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
std
::
complex
<
T
>
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape
,
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape
_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
{
using
namespace
detail
;
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape
,
stride_out
);
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape_in
,
stride_in
);
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape_in
,
stride_out
);
// FIXME
general_r2c
(
ain
,
aout
,
axis
,
fct
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axes
);
r2c
(
shape_in
,
stride_in
,
stride_out
,
axes
.
back
(),
data_in
,
data_out
,
fct
);
if
(
axes
.
size
()
==
1
)
return
;
shape_t
shape_out
(
shape_in
);
shape_out
[
axes
.
back
()]
=
shape_in
[
axes
.
back
()]
/
2
+
1
;
auto
newaxes
=
shape_t
{
axes
.
begin
(),
--
axes
.
end
()};
c2c
(
shape_out
,
stride_out
,
stride_out
,
newaxes
,
true
,
data_out
,
data_out
,
T
(
1
));
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
)
{
using
namespace
detail
;
shape_t
shape_out
(
shape
);
shape_out
[
axis
]
=
new_size
;
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
);
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axis
);
shape_t
shape_in
(
shape_out
);
shape_in
[
axis
]
=
shape_out
[
axis
]
/
2
+
1
;
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape_in
,
stride_in
);
ndarr
<
T
>
aout
(
data_out
,
shape_out
,
stride_out
);
general_c2r
(
ain
,
aout
,
axis
,
fct
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
)
{
using
namespace
detail
;
if
(
axes
.
size
()
==
1
)
{
c2r
(
shape_out
,
stride_in
,
stride_out
,
axes
[
0
],
data_in
,
data_out
,
fct
);
return
;
}
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axes
);
auto
shape_in
=
shape_out
;
shape_in
[
axes
.
back
()]
=
shape_out
[
axes
.
back
()]
/
2
+
1
;
auto
nval
=
util
::
prod
(
shape_in
);
stride_t
stride_inter
(
shape_in
.
size
());
stride_inter
.
back
()
=
sizeof
(
cmplx
<
T
>
);
for
(
int
i
=
shape_in
.
size
()
-
2
;
i
>=
0
;
--
i
)
stride_inter
[
i
]
=
stride_inter
[
i
+
1
]
*
shape_in
[
i
+
1
];
arr
<
complex
<
T
>>
tmp
(
nval
);
auto
newaxes
=
shape_t
({
axes
.
begin
(),
--
axes
.
end
()});
c2c
(
shape_in
,
stride_in
,
stride_inter
,
newaxes
,
false
,
data_in
,
tmp
.
data
(),
T
(
1
));
c2r
(
shape_out
,
stride_inter
,
stride_out
,
axes
.
back
(),
tmp
.
data
(),
data_out
,
fct
);
}
template
<
typename
T
>
void
r2r_fftpack
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
bool
forward
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_r
(
ain
,
aout
,
axis
,
forward
,
fct
);
}
template
<
typename
T
>
void
r2r_hartley
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
T
*
data_in
,
T
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_hartley
(
ain
,
aout
,
axes
,
fct
);
}
...
...
pypocketfft.cc
View file @
fcfd81ff
...
...
@@ -26,7 +26,6 @@ namespace {
using
namespace
std
;
using
namespace
pocketfft
;
using
namespace
pocketfft
::
detail
;
namespace
py
=
pybind11
;
...
...
@@ -83,9 +82,9 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
complex
<
T
>>
(
dims
);
ndarr
<
cmplx
<
T
>>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
ndarr
<
c
mplx
<
T
>
>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
));
general_c
<
T
>
(
ain
,
aout
,
axes
,
fwd
,
fct
);
c2c
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
fwd
,
reinterpret_cast
<
const
co
mpl
e
x
<
T
>
*>
(
in
.
data
()),
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
)
)
;
return
res
;
}
...
...
@@ -108,12 +107,9 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
dims_out
[
axes
.
back
()]
=
(
dims_out
[
axes
.
back
()]
>>
1
)
+
1
;
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims_in
,
copy_strides
(
in
));
ndarr
<
cmplx
<
T
>>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
));
general_r2c
<
T
>
(
ain
,
aout
,
axes
.
back
(),
fct
);
if
(
axes
.
size
()
==
1
)
return
res
;
shape_t
axes2
(
axes
.
begin
(),
--
axes
.
end
());
general_c
<
T
>
(
aout
,
aout
,
axes2
,
true
,
1.
);
r2c
(
dims_in
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
reinterpret_cast
<
const
T
*>
(
in
.
data
()),
reinterpret_cast
<
complex
<
T
>
*>
(
res
.
mutable_data
()),
T
(
fct
));
return
res
;
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
...
...
@@ -126,9 +122,9 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
))
;
general_r
<
T
>
(
ain
,
aout
,
axis
,
fwd
,
fct
);
r2r_fftpack
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axis
,
fwd
,
reinterpret_cast
<
const
T
*>
(
in
.
data
(
))
,
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
)
)
;
return
res
;
}
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
...
...
@@ -148,19 +144,16 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
py
::
array
inter
=
(
axes
.
size
()
==
1
)
?
in
:
xfftn_internal
<
T
>
(
in
,
shape_t
(
axes
.
begin
(),
--
axes
.
end
()),
1.
,
false
,
false
);
size_t
axis
=
axes
.
back
();
if
(
lastsize
==
0
)
lastsize
=
2
*
inter
.
shape
(
axis
)
-
1
;
if
(
ptrdiff_t
(
lastsize
/
2
)
+
1
!=
inter
.
shape
(
axis
))
shape_t
dims_in
(
copy_shape
(
in
)),
dims_out
=
dims_in
;
if
(
lastsize
==
0
)
lastsize
=
2
*
dims_in
[
axis
]
-
1
;
if
((
lastsize
/
2
)
+
1
!=
dims_in
[
axis
])
throw
runtime_error
(
"bad lastsize"
);
auto
dims_out
(
copy_shape
(
inter
));
dims_out
[
axis
]
=
lastsize
;
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
ndarr
<
cmplx
<
T
>>
ain
(
inter
.
data
(),
copy_shape
(
inter
),
copy_strides
(
inter
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
))
;
general_c2r
<
T
>
(
ain
,
aout
,
axis
,
fct
);
c2r
(
dims_in
,
lastsize
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
reinterpret_cast
<
const
complex
<
T
>
*>
(
in
.
data
(
))
,
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()),
T
(
fct
)
)
;
return
res
;
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
...
...
@@ -176,9 +169,9 @@ template<typename T> py::array hartley_internal(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
()
,
copy_s
hape
(
in
),
copy_strides
(
in
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
copy_shape
(
res
),
copy_strides
(
res
))
;
general_hartley
<
T
>
(
ain
,
aout
,
makeaxes
(
in
,
axes_
),
fct
);
r2r_hartley
(
dims
,
copy_s
trides
(
in
),
copy_strides
(
res
),
makeaxes
(
in
,
axes_
),
reinterpret_cast
<
const
T
*>
(
in
.
data
(
))
,
reinterpret_cast
<
T
*>
(
res
.
mutable_data
()
),
T
(
fct
)
)
;
return
res
;
}
py
::
array
hartley
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
...
...
@@ -192,6 +185,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
const
py
::
array
&
tmp
,
py
::
object
axes_
,
bool
inplace
)
{
using
namespace
pocketfft
::
detail
;
int
ndim
=
in
.
ndim
();
auto
dims_out
(
copy_shape
(
in
));
py
::
array
out
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims_out
);
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment