Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
e045a9eb
Commit
e045a9eb
authored
Apr 20, 2019
by
Martin Reinecke
Browse files
fixes
parent
1e3079c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
pocketfft.cc
View file @
e045a9eb
...
...
@@ -1953,7 +1953,7 @@ class multiarr
class
multi_iter
{
p
rivate
:
p
ublic
:
vector
<
diminfo
>
dim
;
vector
<
size_t
>
pos
;
size_t
ofs_
,
len
;
...
...
@@ -2538,6 +2538,85 @@ py::array hartley(const py::array &in, py::object axes_, double fct)
return
res
;
}
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
const
py
::
array
&
tmp
,
py
::
object
axes_
)
{
int
ndim
=
in
.
ndim
();
vector
<
size_t
>
dims_out
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
dims_out
[
i
]
=
in
.
shape
(
i
);
py
::
array
out
;
if
(
in
.
dtype
().
is
(
f64
))
out
=
py
::
array_t
<
double
>
(
dims_out
);
else
if
(
in
.
dtype
().
is
(
f32
))
out
=
py
::
array_t
<
float
>
(
dims_out
);
else
throw
runtime_error
(
"unsupported data type"
);
vector
<
size_t
>
dims_tmp
(
ndim
);
vector
<
int64_t
>
stride_tmp
(
ndim
),
stride_out
(
ndim
);
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
dims_tmp
[
i
]
=
tmp
.
shape
(
i
);
stride_tmp
[
i
]
=
tmp
.
strides
(
i
)
/
tmp
.
itemsize
();
if
(
stride_tmp
[
i
]
*
tmp
.
itemsize
()
!=
tmp
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
stride_out
[
i
]
=
out
.
strides
(
i
)
/
out
.
itemsize
();
if
(
stride_out
[
i
]
*
out
.
itemsize
()
!=
out
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
}
auto
axes
=
makeaxes
(
in
,
axes_
);
int
axis
=
axes
.
back
();
multiarr
a_tmp
(
ndim
,
dims_tmp
.
data
(),
stride_tmp
.
data
()),
a_out
(
ndim
,
dims_out
.
data
(),
stride_out
.
data
());
multi_iter
it_tmp
(
a_tmp
,
axis
),
it_out
(
a_out
,
axis
);
const
cmplx
<
T
>
*
tdata
=
(
const
cmplx
<
T
>
*
)
tmp
.
data
();
T
*
odata
=
(
T
*
)
out
.
mutable_data
();
vector
<
bool
>
swp
(
ndim
-
1
,
false
);
for
(
auto
i
:
axes
)
{
if
(
i
!=
axis
)
{
auto
i2
=
i
<
axis
?
i
:
i
-
1
;
swp
[
i2
]
=
true
;
}
}
while
(
!
it_tmp
.
done
())
{
auto
tofs
=
it_tmp
.
offset
();
auto
ofs
=
it_out
.
offset
();
size_t
rofs
=
0
;
for
(
size_t
i
=
0
;
i
<
it_out
.
pos
.
size
();
++
i
)
{
if
(
!
swp
[
i
])
rofs
+=
it_out
.
pos
[
i
]
*
it_out
.
dim
[
i
].
s
;
else
{
auto
x
=
(
it_out
.
pos
[
i
]
==
0
)
?
0
:
it_out
.
dim
[
i
].
n
-
it_out
.
pos
[
i
];
rofs
+=
x
*
it_out
.
dim
[
i
].
s
;
}
}
for
(
size_t
i
=
0
;
i
<
it_tmp
.
length
();
++
i
)
{
auto
re
=
tdata
[
tofs
+
i
*
it_tmp
.
stride
()].
r
;
auto
im
=
tdata
[
tofs
+
i
*
it_tmp
.
stride
()].
i
;
auto
rev_i
=
(
i
==
0
)
?
0
:
it_out
.
length
()
-
i
;
odata
[
ofs
+
i
*
it_out
.
stride
()]
=
re
+
im
;
odata
[
rofs
+
rev_i
*
it_out
.
stride
()]
=
re
-
im
;
}
it_tmp
.
advance
();
it_out
.
advance
();
}
return
out
;
}
py
::
array
hartley2
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
{
auto
tmp
=
rfftn
(
in
,
axes_
,
fct
);
if
(
in
.
dtype
().
is
(
f64
))
return
complex2hartley
<
double
>
(
in
,
tmp
,
axes_
);
else
if
(
in
.
dtype
().
is
(
f32
))
return
complex2hartley
<
float
>
(
in
,
tmp
,
axes_
);
else
throw
runtime_error
(
"unsupported data type"
);
}
const
char
*
pypocketfft_DS
=
R"DELIM(
Fast Fourier and Hartley transforms.
...
...
@@ -2667,4 +2746,5 @@ PYBIND11_MODULE(pypocketfft, m)
m
.
def
(
"rfftn"
,
&
rfftn
,
rfftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
);
m
.
def
(
"irfftn"
,
&
irfftn
,
irfftn_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"lastsize"
_a
=
0
,
"fct"
_a
=
1.
);
m
.
def
(
"hartley"
,
&
hartley
,
hartley_DS
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
);
m
.
def
(
"hartley2"
,
&
hartley2
,
"a"
_a
,
"axes"
_a
=
py
::
none
(),
"fct"
_a
=
1.
);
}
test.py
View file @
e045a9eb
...
...
@@ -71,18 +71,33 @@ def test_identity_r2(shp):
vmax
=
np
.
max
(
np
.
abs
(
a
))
assert_allclose
(
pypocketfft
.
rfftn
(
pypocketfft
.
irfftn
(
a
),
fct
=
fct
),
a
,
atol
=
2e-15
*
vmax
,
rtol
=
0
)
@
pmp
(
"shp"
,
shapes3D
)
def
x
test_hartley
(
shp
):
@
pmp
(
"shp"
,
shapes2D
+
shapes3D
)
def
test_hartley
2
(
shp
):
a
=
np
.
random
.
rand
(
*
shp
)
-
0.5
v1
=
pypocketfft
.
hartley
(
a
)
v1
=
pypocketfft
.
hartley
2
(
a
)
v2
=
pypocketfft
.
fftn
(
a
.
astype
(
np
.
complex128
))
vmax
=
np
.
max
(
np
.
abs
(
v1
))
v2
=
v2
.
real
+
v2
.
imag
assert_allclose
(
v1
,
v2
,
atol
=
2e-15
*
vmax
,
rtol
=
0
)
@
pmp
(
"shp"
,
shapes3D
)
@
pmp
(
"shp"
,
shapes
)
def
test_hartley_identity
(
shp
):
a
=
np
.
random
.
rand
(
*
shp
)
-
0.5
v1
=
pypocketfft
.
hartley
(
pypocketfft
.
hartley
(
a
))
/
a
.
size
vmax
=
np
.
max
(
np
.
abs
(
a
))
assert_allclose
(
a
,
v1
,
atol
=
2e-15
*
vmax
,
rtol
=
0
)
@
pmp
(
"shp"
,
shapes
)
def
test_hartley2_identity
(
shp
):
a
=
np
.
random
.
rand
(
*
shp
)
-
0.5
v1
=
pypocketfft
.
hartley2
(
pypocketfft
.
hartley2
(
a
))
/
a
.
size
vmax
=
np
.
max
(
np
.
abs
(
a
))
assert_allclose
(
a
,
v1
,
atol
=
2e-15
*
vmax
,
rtol
=
0
)
@
pmp
(
"shp"
,
shapes2D
)
@
pmp
(
"axes"
,
((
0
,),(
1
,),(
0
,
1
),(
1
,
0
)))
def
test_hartley2_2D
(
shp
,
axes
):
a
=
np
.
random
.
rand
(
*
shp
)
-
0.5
fct
=
1.
/
np
.
prod
(
np
.
take
(
shp
,
axes
))
vmax
=
np
.
max
(
np
.
abs
(
a
))
assert_allclose
(
pypocketfft
.
hartley2
(
pypocketfft
.
hartley2
(
a
,
axes
=
axes
),
axes
=
axes
,
fct
=
fct
),
a
,
atol
=
2e-15
*
vmax
,
rtol
=
0
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment