Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
pypocketfft
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
4
Issues
4
List
Boards
Labels
Service Desk
Milestones
Merge Requests
2
Merge Requests
2
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Martin Reinecke
pypocketfft
Commits
89a7dd70
Commit
89a7dd70
authored
May 09, 2019
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
synchronize
parent
39de2274
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
25 deletions
+95
-25
pocketfft_hdronly.h
pocketfft_hdronly.h
+81
-0
pypocketfft.cc
pypocketfft.cc
+14
-25
No files found.
pocketfft_hdronly.h
View file @
89a7dd70
...
@@ -377,6 +377,42 @@ struct util // hack to avoid duplicate symbols
...
@@ -377,6 +377,42 @@ struct util // hack to avoid duplicate symbols
res
*=
sz
;
res
*=
sz
;
return
res
;
return
res
;
}
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
)
{
auto
ndim
=
shape
.
size
();
if
(
ndim
<
1
)
throw
runtime_error
(
"ndim must be >= 1"
);
if
((
stride_in
.
size
()
!=
ndim
)
||
(
stride_out
.
size
()
!=
ndim
))
throw
runtime_error
(
"stride dimension mismatch"
);
for
(
auto
shp
:
shape
)
if
(
shp
<
1
)
throw
runtime_error
(
"zero extent detected"
);
if
(
inplace
)
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
if
(
stride_in
[
i
]
!=
stride_out
[
i
])
throw
runtime_error
(
"stride mismatch"
);
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
const
shape_t
&
axes
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
auto
ndim
=
shape
.
size
();
shape_t
tmp
(
ndim
,
0
);
for
(
auto
ax
:
axes
)
{
if
(
ax
>=
ndim
)
throw
runtime_error
(
"bad axis number"
);
if
(
++
tmp
[
ax
]
>
1
)
throw
runtime_error
(
"axis specified repeatedly"
);
}
}
static
NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
size_t
axis
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
if
(
axis
>=
shape
.
size
())
throw
runtime_error
(
"bad axis number"
);
}
};
};
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
...
@@ -2350,6 +2386,7 @@ template<typename T> void c2c(const shape_t &shape,
...
@@ -2350,6 +2386,7 @@ template<typename T> void c2c(const shape_t &shape,
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
{
using
namespace
detail
;
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
),
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
aout
(
data_out
,
shape
,
stride_out
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
);
general_c
(
ain
,
aout
,
axes
,
forward
,
fct
);
...
@@ -2360,16 +2397,34 @@ template<typename T> void r2c(const shape_t &shape,
...
@@ -2360,16 +2397,34 @@ template<typename T> void r2c(const shape_t &shape,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
{
using
namespace
detail
;
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape
,
stride_out
);
ndarr
<
cmplx
<
T
>>
aout
(
data_out
,
shape
,
stride_out
);
general_r2c
(
ain
,
aout
,
axis
,
fct
);
general_r2c
(
ain
,
aout
,
axis
,
fct
);
}
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
r2c
(
shape
,
stride_in
,
stride_out
,
axes
.
back
(),
data_in
,
data_out
,
fct
);
if
(
axes
.
size
()
==
1
)
return
;
shape_t
shape_out
(
shape
);
shape_out
[
axes
.
back
()]
=
shape
[
axes
.
back
()]
/
2
+
1
;
auto
newaxes
=
shape_t
{
axes
.
begin
(),
--
axes
.
end
()};
c2c
(
shape_out
,
stride_out
,
stride_out
,
newaxes
,
true
,
data_out
,
data_out
,
T
(
1
));
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
{
using
namespace
detail
;
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
shape_t
shape_out
(
shape
);
shape_t
shape_out
(
shape
);
shape_out
[
axis
]
=
new_size
;
shape_out
[
axis
]
=
new_size
;
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
);
ndarr
<
cmplx
<
T
>>
ain
(
data_in
,
shape
,
stride_in
);
...
@@ -2377,11 +2432,36 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
...
@@ -2377,11 +2432,36 @@ template<typename T> void c2r(const shape_t &shape, size_t new_size,
general_c2r
(
ain
,
aout
,
axis
,
fct
);
general_c2r
(
ain
,
aout
,
axis
,
fct
);
}
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape
,
size_t
new_size
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
using
namespace
detail
;
if
(
axes
.
size
()
==
1
)
{
c2r
(
shape
,
new_size
,
stride_in
,
stride_out
,
axes
[
0
],
data_in
,
data_out
,
fct
);
return
;
}
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
auto
nval
=
util
::
prod
(
shape
);
stride_t
stride_inter
(
shape
.
size
());
stride_inter
.
back
()
=
sizeof
(
cmplx
<
T
>
);
for
(
int
i
=
shape
.
size
()
-
2
;
i
>=
0
;
--
i
)
stride_inter
[
i
]
=
stride_inter
[
i
+
1
]
*
shape
[
i
+
1
];
arr
<
char
>
tmp
(
nval
*
sizeof
(
cmplx
<
T
>
));
auto
newaxes
=
shape_t
({
axes
.
begin
(),
--
axes
.
end
()});
c2c
(
shape
,
stride_in
,
stride_inter
,
newaxes
,
false
,
data_in
,
tmp
.
data
(),
T
(
1
));
c2r
(
shape
,
new_size
,
stride_inter
,
stride_out
,
axes
.
back
(),
tmp
.
data
(),
data_out
,
fct
);
}
template
<
typename
T
>
void
r2r_fftpack
(
const
shape_t
&
shape
,
template
<
typename
T
>
void
r2r_fftpack
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
bool
forward
,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
{
using
namespace
detail
;
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axis
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_r
(
ain
,
aout
,
axis
,
forward
,
fct
);
general_r
(
ain
,
aout
,
axis
,
forward
,
fct
);
}
}
...
@@ -2391,6 +2471,7 @@ template<typename T> void r2r_hartley(const shape_t &shape,
...
@@ -2391,6 +2471,7 @@ template<typename T> void r2r_hartley(const shape_t &shape,
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
const
void
*
data_in
,
void
*
data_out
,
T
fct
)
{
{
using
namespace
detail
;
using
namespace
detail
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
ndarr
<
T
>
ain
(
data_in
,
shape
,
stride_in
),
aout
(
data_out
,
shape
,
stride_out
);
general_hartley
(
ain
,
aout
,
axes
,
fct
);
general_hartley
(
ain
,
aout
,
axes
,
fct
);
}
}
...
...
pypocketfft.cc
View file @
89a7dd70
...
@@ -26,7 +26,6 @@ namespace {
...
@@ -26,7 +26,6 @@ namespace {
using
namespace
std
;
using
namespace
std
;
using
namespace
pocketfft
;
using
namespace
pocketfft
;
using
namespace
pocketfft
::
detail
;
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
...
@@ -83,9 +82,8 @@ template<typename T> py::array xfftn_internal(const py::array &in,
...
@@ -83,9 +82,8 @@ template<typename T> py::array xfftn_internal(const py::array &in,
{
{
auto
dims
(
copy_shape
(
in
));
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
complex
<
T
>>
(
dims
);
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
complex
<
T
>>
(
dims
);
ndarr
<
cmplx
<
T
>>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
c2c
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
fwd
,
in
.
data
(),
ndarr
<
cmplx
<
T
>>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
));
res
.
mutable_data
(),
T
(
fct
));
general_c
<
T
>
(
ain
,
aout
,
axes
,
fwd
,
fct
);
return
res
;
return
res
;
}
}
...
@@ -108,12 +106,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
...
@@ -108,12 +106,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
dims_out
[
axes
.
back
()]
=
(
dims_out
[
axes
.
back
()]
>>
1
)
+
1
;
dims_out
[
axes
.
back
()]
=
(
dims_out
[
axes
.
back
()]
>>
1
)
+
1
;
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims_in
,
copy_strides
(
in
));
r2c
(
dims_in
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
in
.
data
(),
ndarr
<
cmplx
<
T
>>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
));
res
.
mutable_data
(),
T
(
fct
));
general_r2c
<
T
>
(
ain
,
aout
,
axes
.
back
(),
fct
);
if
(
axes
.
size
()
==
1
)
return
res
;
shape_t
axes2
(
axes
.
begin
(),
--
axes
.
end
());
general_c
<
T
>
(
aout
,
aout
,
axes2
,
true
,
1.
);
return
res
;
return
res
;
}
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
...
@@ -126,9 +120,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
...
@@ -126,9 +120,8 @@ template<typename T> py::array xrfft_scipy(const py::array &in,
{
{
auto
dims
(
copy_shape
(
in
));
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
(),
dims
,
copy_strides
(
in
));
r2r_fftpack
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
axis
,
fwd
,
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims
,
copy_strides
(
res
));
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
general_r
<
T
>
(
ain
,
aout
,
axis
,
fwd
,
fct
);
return
res
;
return
res
;
}
}
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
py
::
array
rfft_scipy
(
const
py
::
array
&
in
,
size_t
axis
,
double
fct
,
bool
inplace
)
...
@@ -148,19 +141,15 @@ template<typename T> py::array irfftn_internal(const py::array &in,
...
@@ -148,19 +141,15 @@ template<typename T> py::array irfftn_internal(const py::array &in,
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
{
{
auto
axes
=
makeaxes
(
in
,
axes_
);
auto
axes
=
makeaxes
(
in
,
axes_
);
py
::
array
inter
=
(
axes
.
size
()
==
1
)
?
in
:
xfftn_internal
<
T
>
(
in
,
shape_t
(
axes
.
begin
(),
--
axes
.
end
()),
1.
,
false
,
false
);
size_t
axis
=
axes
.
back
();
size_t
axis
=
axes
.
back
();
if
(
lastsize
==
0
)
lastsize
=
2
*
inter
.
shape
(
axis
)
-
1
;
shape_t
dims_in
(
copy_shape
(
in
)),
dims_out
=
dims_in
;
if
(
ptrdiff_t
(
lastsize
/
2
)
+
1
!=
inter
.
shape
(
axis
))
if
(
lastsize
==
0
)
lastsize
=
2
*
dims_in
[
axis
]
-
1
;
if
((
lastsize
/
2
)
+
1
!=
dims_in
[
axis
])
throw
runtime_error
(
"bad lastsize"
);
throw
runtime_error
(
"bad lastsize"
);
auto
dims_out
(
copy_shape
(
inter
));
dims_out
[
axis
]
=
lastsize
;
dims_out
[
axis
]
=
lastsize
;
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
ndarr
<
cmplx
<
T
>>
ain
(
inter
.
data
(),
copy_shape
(
inter
),
copy_strides
(
inter
));
c2r
(
dims_in
,
lastsize
,
copy_strides
(
in
),
copy_strides
(
res
),
axes
,
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
dims_out
,
copy_strides
(
res
));
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
general_c2r
<
T
>
(
ain
,
aout
,
axis
,
fct
);
return
res
;
return
res
;
}
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
...
@@ -176,9 +165,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
...
@@ -176,9 +165,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
{
{
auto
dims
(
copy_shape
(
in
));
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
py
::
array
res
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims
);
ndarr
<
T
>
ain
(
in
.
data
(),
copy_shape
(
in
),
copy_strides
(
in
));
r2r_hartley
(
dims
,
copy_strides
(
in
),
copy_strides
(
res
),
makeaxes
(
in
,
axes_
),
ndarr
<
T
>
aout
(
res
.
mutable_data
(),
copy_shape
(
res
),
copy_strides
(
res
));
in
.
data
(),
res
.
mutable_data
(),
T
(
fct
));
general_hartley
<
T
>
(
ain
,
aout
,
makeaxes
(
in
,
axes_
),
fct
);
return
res
;
return
res
;
}
}
py
::
array
hartley
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
py
::
array
hartley
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
,
...
@@ -192,6 +180,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
...
@@ -192,6 +180,7 @@ py::array hartley(const py::array &in, py::object axes_, double fct,
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
template
<
typename
T
>
py
::
array
complex2hartley
(
const
py
::
array
&
in
,
const
py
::
array
&
tmp
,
py
::
object
axes_
,
bool
inplace
)
const
py
::
array
&
tmp
,
py
::
object
axes_
,
bool
inplace
)
{
{
using
namespace
pocketfft
::
detail
;
int
ndim
=
in
.
ndim
();
int
ndim
=
in
.
ndim
();
auto
dims_out
(
copy_shape
(
in
));
auto
dims_out
(
copy_shape
(
in
));
py
::
array
out
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims_out
);
py
::
array
out
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims_out
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment