Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
89a7dd70
Commit
89a7dd70
authored
May 09, 2019
by
Martin Reinecke
Browse files
synchronize
parent
39de2274
Changes
2
Hide whitespace changes
Inline
Side-by-side
pocketfft_hdronly.h
View file @
89a7dd70
...
...
@@ -377,6 +377,42 @@ struct util // hack to avoid duplicate symbols
res
*=
sz
;
return
res
;
}
// Basic consistency check for a transform request.
// Throws runtime_error if
//  - shape has no dimensions,
//  - either stride vector has a different length than shape,
//  - any entry of shape is smaller than 1,
//  - inplace is set and the input/output strides differ in any dimension.
// The checks are carried out in exactly this order.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace)
  {
  const auto rank = shape.size();
  if (rank<1)
    throw runtime_error("ndim must be >= 1");

  const bool in_matches = !(stride_in.size()!=rank);
  const bool out_matches = !(stride_out.size()!=rank);
  if (!(in_matches && out_matches))
    throw runtime_error("stride dimension mismatch");

  for (size_t d=0; d<rank; ++d)
    if (shape[d]<1)
      throw runtime_error("zero extent detected");

  // stride equality is only required when input and output alias
  if (!inplace) return;
  for (size_t d=0; d<rank; ++d)
    {
    if (stride_in[d]!=stride_out[d])
      throw runtime_error("stride mismatch");
    }
  }
// Consistency check for a transform over several axes.
// Runs the basic shape/stride check first, then throws runtime_error if an
// entry of axes is not smaller than the number of dimensions, or if the same
// axis is listed more than once.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace,
  const shape_t &axes)
  {
  sanity_check(shape, stride_in, stride_out, inplace);

  const auto rank = shape.size();
  // seen[k] counts how many times axis k has been requested so far
  shape_t seen(rank,0);
  for (auto it=axes.begin(); it!=axes.end(); ++it)
    {
    const auto ax = *it;
    if (!(ax<rank))
      throw runtime_error("bad axis number");
    seen[ax] += 1;
    if (seen[ax]>1)
      throw runtime_error("axis specified repeatedly");
    }
  }
// Consistency check for a transform along a single axis.
// Runs the basic shape/stride check first, then throws runtime_error if axis
// is not smaller than the number of dimensions.
static NOINLINE void sanity_check(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, bool inplace,
  size_t axis)
  {
  sanity_check(shape, stride_in, stride_out, inplace);

  const auto rank = shape.size();
  if (!(axis<rank))
    throw runtime_error("bad axis number");
  }
};
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
...
...
@@ -2350,6 +2386,7 @@ template<typename T> void c2c(const shape_t &shape,
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
);
...
...
@@ -2360,16 +2397,34 @@ template<typename T> void r2c(const shape_t &shape,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape
,
stride_out
);
general_r2c
(
ain
,
aout
,
axis
,
fct
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
r2c
(
shape
,
stride_in
,
stride_out
,
axes
.
back
(),
data_in
,
data_out
,
fct
);
if
(
axes
.
size
()
==
1
)
return
;
shape_t
shape_out
(
shape
);
shape_out
[
axes
.
back
()]
=
shape
[
axes
.
back
()]
/
2
+
1
;
auto
newaxes
=
shape_t
{
axes
.
begin
(),
--
axes
.
end
()};
c2c
(
shape_out
,
stride_out
,
stride_out
,
newaxes
,
true
,
data_out
,
data_out
,
T
(
1
));
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
shape_t
shape_out
(
shape
);
shape_out
[
axis
]
=
new_size
;
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
);
...
...
@@ -2377,11 +2432,36 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
general_c2r
(
ain
,
aout
,
axis
,
fct
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
if
(
axes
.
size
()
==
1
)
{
c2r
(
shape
,
new_size
,
stride_in
,
stride_out
,
axes
[
0
],
data_in
,
data_out
,
fct
);
return
;
}
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
auto
nval
=
util
::
prod
(
shape
);
stride_t
stride_inter
(
shape
.
size
());
stride_inter
.
back
()
=
sizeof
(
cmplx
<
T
>
);
for
(
int
i
=
shape
.
size
()
-
2
;
i
>=
0
;
--
i
)
stride_inter
[
i
]
=
stride_inter
[
i
+
1
]
*
shape
[
i
+
1
];
arr
<
char
>
tmp
(
nval
*
sizeof
(
cmplx
<
T
>
));
auto
newaxes
=
shape_t
({
axes
.
begin
(),
--
axes
.
end
()});
c2c
(
shape
,
stride_in
,
stride_inter
,
newaxes
,
false
,
data_in
,
tmp
.
data
(),
T
(
1
));
c2r
(
shape
,
new_size
,
stride_inter
,
stride_out
,
axes
.
back
(),
tmp
.
data
(),
data_out
,
fct
);
}
// Real-to-real transform along one axis.
// Validates the arguments with util::sanity_check, wraps data_in and data_out
// as ndarr<T> views sharing the same shape, and hands them to general_r
// together with axis, forward and fct.
template<typename T> void r2r_fftpack(const shape_t &shape,
  const stride_t &stride_in, const stride_t &stride_out, size_t axis,
  bool forward, const void *data_in, void *data_out, T fct)
  {
  using namespace detail;

  // input and output are treated as aliasing when the pointers are equal
  const bool same_buffer = (data_in==data_out);
  util::sanity_check(shape, stride_in, stride_out, same_buffer, axis);

  ndarr<T> ain(data_in, shape, stride_in);
  ndarr<T> aout(data_out, shape, stride_out);
  general_r(ain, aout, axis, forward, fct);
  }
...
...
@@ -2391,6 +2471,7 @@ template<typename T> void r2r_hartley(const shape_t &shape,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_hartley
(
ain
,
aout
,
axes
,
fct
);
}
...
...
pypocketfft.cc
View file @
89a7dd70
...
...
@@ -26,7 +26,6 @@ namespace {
using
namespace
std
;
using
namespace
pocketfft
;
using
namespace
pocketfft
::
detail
;
namespace
py
=
pybind11
;
...
...
@@ -83,9 +82,8 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
complex
<
T
>>
(
dims
);
ndarr
<
cmplx
<
T
>>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
ndarr
<
cmplx
<
T
>>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
));
general_c
<
T
>
(
ain
,
aout
,
axes
,
fwd
,
fct
);
c2c
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
fwd
,
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
return
res
;
}
...
...
@@ -108,12 +106,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
dims_out
[
axes
.
back
()]
=
(
dims_out
[
axes
.
back
()]
>>
1
)
+
1
;
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims_in
,
copy_strides
(
in
));
ndarr
<
cmplx
<
T
>>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
));
general_r2c
<
T
>
(
ain
,
aout
,
axes
.
back
(),
fct
);
if
(
axes
.
size
()
==
1
)
return
res
;
shape_t
axes2
(
axes
.
begin
(),
--
axes
.
end
());
general_c
<
T
>
(
aout
,
aout
,
axes2
,
true
,
1.
);
r2c
(
dims_in
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
return
res
;
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
...
...
@@ -126,9 +120,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
));
general_r
<
T
>
(
ain
,
aout
,
axis
,
fwd
,
fct
);
r2r_fftpack
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axis
,
fwd
,
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
return
res
;
}
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
...
...
@@ -148,19 +141,15 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
py
::
array
inter
=
(
axes
.
size
()
==
1
)
?
in
:
xfftn_internal
<
T
>
(
in
,
shape_t
(
axes
.
begin
(),
--
axes
.
end
()),
1.
,
false
,
false
);
size_t
axis
=
axes
.
back
();
if
(
lastsize
==
0
)
lastsize
=
2
*
inter
.
shape
(
axis
)
-
1
;
if
(
ptrdiff_t
(
lastsize
/
2
)
+
1
!=
inter
.
shape
(
axis
))
shape_t
dims_in
(
copy_shape
(
in
)),
dims_out
=
dims_in
;
if
(
lastsize
==
0
)
lastsize
=
2
*
dims_in
[
axis
]
-
1
;
if
((
lastsize
/
2
)
+
1
!=
dims_in
[
axis
])
throw
runtime_error
(
"bad lastsize"
);
auto
dims_out
(
copy_shape
(
inter
));
dims_out
[
axis
]
=
lastsize
;
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
ndarr
<
cmplx
<
T
>>
ain
(
inter
.
data
(),
copy_shape
(
inter
),
copy_strides
(
inter
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
));
general_c2r
<
T
>
(
ain
,
aout
,
axis
,
fct
);
c2r
(
dims_in
,
lastsize
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
return
res
;
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
...
...
@@ -176,9 +165,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
(),
copy_shape
(
in
),
copy_strides
(
in
));
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
copy_shape
(
res
),
copy_strides
(
res
));
general_hartley
<
T
>
(
ain
,
aout
,
makeaxes
(
in
,
axes_
),
fct
);
r2r_hartley
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
makeaxes
(
in
,
axes_
),
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
return
res
;
}
py
::
array
hartley
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
...
...
@@ -192,6 +180,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
const
py
::
array
&
tmp
,
py
::
object
axes_
,
bool
inplace
)
{
using
namespace
pocketfft
::
detail
;
int
ndim
=
in
.
ndim
();
auto
dims_out
(
copy_shape
(
in
));
py
::
array
out
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims_out
);
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment