Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
pypocketfft
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
4
Issues
4
List
Boards
Labels
Service Desk
Milestones
Merge Requests
2
Merge Requests
2
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Martin Reinecke
pypocketfft
Commits
97730839
Commit
97730839
authored
Apr 23, 2019
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
cleanup
parent
d2dd06b3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
63 deletions
+52
-63
pocketfft.cc
pocketfft.cc
+52
-63
No files found.
pocketfft.cc
View file @
97730839
...
...
@@ -2359,6 +2359,17 @@ vector<size_t> copy_shape(const py::array &arr)
res
[
i
]
=
arr
.
shape
(
i
);
return
res
;
}
vector
<
int64_t
>
copy_strides
(
const
py
::
array
&
arr
)
{
vector
<
int64_t
>
res
(
arr
.
ndim
());
for
(
size_t
i
=
0
;
i
<
res
.
size
();
++
i
)
{
res
[
i
]
=
arr
.
strides
(
i
)
/
arr
.
itemsize
();
if
(
res
[
i
]
*
arr
.
itemsize
()
!=
arr
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
}
return
res
;
}
vector
<
size_t
>
makeaxes
(
const
py
::
array
&
in
,
py
::
object
axes
)
{
...
...
@@ -2377,45 +2388,33 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
throw
runtime_error
(
"invalid axis number"
);
return
tmp
;
}
void
make_strides
(
const
py
::
array
&
in
,
const
py
::
array
&
out
,
vector
<
int64_t
>
&
stride_in
,
vector
<
int64_t
>
&
stride_out
)
{
for
(
size_t
i
=
0
;
i
<
stride_in
.
size
();
++
i
)
{
stride_in
[
i
]
=
in
.
strides
(
i
)
/
in
.
itemsize
();
if
(
stride_in
[
i
]
*
in
.
itemsize
()
!=
in
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
stride_out
[
i
]
=
out
.
strides
(
i
)
/
out
.
itemsize
();
if
(
stride_out
[
i
]
*
out
.
itemsize
()
!=
out
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
}
}
py
::
array
execute
(
const
py
::
array
&
in
,
vector
<
size_t
>
axes
,
double
fct
,
bool
fwd
)
template
<
typename
T
>
py
::
array
execute
(
const
py
::
array
&
in
,
const
vector
<
size_t
>
&
axes
,
double
fct
,
bool
fwd
)
{
py
::
array
res
;
auto
dims
(
copy_shape
(
in
));
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
if
(
in
.
dtype
().
is
(
c128
))
res
=
py
::
array_t
<
complex
<
double
>>
(
dims
);
else
if
(
in
.
dtype
().
is
(
c64
))
res
=
py
::
array_t
<
complex
<
float
>>
(
dims
);
else
throw
runtime_error
(
"unsupported data type"
);
make_strides
(
in
,
res
,
s_i
,
s_o
);
if
(
in
.
dtype
().
is
(
c128
))
pocketfft_general_c
<
double
>
(
dims
,
s_i
,
s_o
,
axes
,
fwd
,
(
const
cmplx
<
double
>
*
)
in
.
data
(),
(
cmplx
<
double
>
*
)
res
.
mutable_data
(),
fct
);
else
pocketfft_general_c
<
float
>
(
dims
,
s_i
,
s_o
,
axes
,
fwd
,
(
const
cmplx
<
float
>
*
)
in
.
data
(),
(
cmplx
<
float
>
*
)
res
.
mutable_data
(),
fct
);
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims
);
auto
s_i
(
copy_strides
(
in
)),
s_o
(
copy_strides
(
res
));
pocketfft_general_c
<
T
>
(
dims
,
s_i
,
s_o
,
axes
,
fwd
,
(
const
cmplx
<
T
>
*
)
in
.
data
(),
(
cmplx
<
T
>
*
)
res
.
mutable_data
(),
fct
);
return
res
;
}
py
::
array
fftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
)
{
return
execute
(
a
,
makeaxes
(
a
,
axes
),
fct
,
true
);
}
{
if
(
a
.
dtype
().
is
(
c128
))
return
execute
<
double
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
true
);
else
if
(
a
.
dtype
().
is
(
c64
))
return
execute
<
float
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
true
);
throw
runtime_error
(
"unsupported data type"
);
}
py
::
array
ifftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
)
{
return
execute
(
a
,
makeaxes
(
a
,
axes
),
fct
,
false
);
}
{
if
(
a
.
dtype
().
is
(
c128
))
return
execute
<
double
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
false
);
else
if
(
a
.
dtype
().
is
(
c64
))
return
execute
<
float
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
false
);
throw
runtime_error
(
"unsupported data type"
);
}
template
<
typename
T
>
py
::
array
rfftn_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
T
fct
)
...
...
@@ -2423,9 +2422,8 @@ template<typename T> py::array rfftn_internal(const py::array &in,
auto
axes
=
makeaxes
(
in
,
axes_
);
auto
dims_in
(
copy_shape
(
in
)),
dims_out
(
dims_in
);
dims_out
[
axes
.
back
()]
=
(
dims_out
[
axes
.
back
()]
>>
1
)
+
1
;
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims_out
);
make_strides
(
in
,
res
,
s_i
,
s_o
);
auto
s_i
(
copy_strides
(
in
)),
s_o
(
copy_strides
(
res
)
);
pocketfft_general_r2c
<
T
>
(
dims_in
,
s_i
,
s_o
,
axes
.
back
(),
(
const
T
*
)
in
.
data
(),
(
cmplx
<
T
>
*
)
res
.
mutable_data
(),
fct
);
if
(
axes
.
size
()
==
1
)
return
res
;
...
...
@@ -2433,8 +2431,7 @@ template<typename T> py::array rfftn_internal(const py::array &in,
for
(
size_t
i
=
0
;
i
<
axes2
.
size
();
++
i
)
axes2
[
i
]
=
axes
[
i
];
pocketfft_general_c
<
T
>
(
dims_out
,
s_o
,
s_o
,
axes2
,
true
,
(
const
cmplx
<
T
>
*
)
res
.
data
(),
(
cmplx
<
T
>
*
)
res
.
mutable_data
(),
1.
);
(
const
cmplx
<
T
>
*
)
res
.
data
(),
(
cmplx
<
T
>
*
)
res
.
mutable_data
(),
1.
);
return
res
;
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
...
...
@@ -2445,8 +2442,8 @@ py::array rfftn(const py::array &in, py::object axes_, double fct)
return
rfftn_internal
<
float
>
(
in
,
axes_
,
fct
);
else
throw
runtime_error
(
"unsupported data type"
);
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
double
fct
)
template
<
typename
T
>
py
::
array
irfftn_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
T
fct
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
py
::
array
inter
;
...
...
@@ -2455,7 +2452,7 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
vector
<
size_t
>
axes2
(
axes
.
size
()
-
1
);
for
(
size_t
i
=
0
;
i
<
axes2
.
size
();
++
i
)
axes2
[
i
]
=
axes
[
i
];
inter
=
execute
(
in
,
axes2
,
1.
,
false
);
inter
=
execute
<
T
>
(
in
,
axes2
,
1.
,
false
);
}
else
inter
=
in
;
...
...
@@ -2465,27 +2462,22 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
if
(
int64_t
(
lastsize
/
2
)
+
1
!=
inter
.
shape
(
axis
))
throw
runtime_error
(
"bad lastsize"
);
auto
dims_in
(
copy_shape
(
inter
)),
dims_out
(
dims_in
);
vector
<
int64_t
>
s_i
(
inter
.
ndim
()),
s_o
(
inter
.
ndim
());
dims_out
[
axis
]
=
lastsize
;
py
::
array
res
;
if
(
inter
.
dtype
().
is
(
c128
))
res
=
py
::
array_t
<
double
>
(
dims_out
);
else
if
(
inter
.
dtype
().
is
(
c64
))
res
=
py
::
array_t
<
float
>
(
dims_out
);
else
throw
runtime_error
(
"unsupported data type"
);
make_strides
(
inter
,
res
,
s_i
,
s_o
);
if
(
inter
.
dtype
().
is
(
c128
))
pocketfft_general_c2r
<
double
>
(
dims_out
,
s_i
,
s_o
,
axis
,
(
const
cmplx
<
double
>
*
)
inter
.
data
(),
(
double
*
)
res
.
mutable_data
(),
fct
);
else
pocketfft_general_c2r
<
float
>
(
dims_out
,
s_i
,
s_o
,
axis
,
(
const
cmplx
<
float
>
*
)
inter
.
data
(),
(
float
*
)
res
.
mutable_data
(),
fct
);
py
::
array
res
=
py
::
array_t
<
T
>
(
dims_out
);
auto
s_i
(
copy_strides
(
inter
)),
s_o
(
copy_strides
(
res
));
pocketfft_general_c2r
<
T
>
(
dims_out
,
s_i
,
s_o
,
axis
,
(
const
cmplx
<
T
>
*
)
inter
.
data
(),
(
T
*
)
res
.
mutable_data
(),
fct
);
return
res
;
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
py
::
object
axes_
,
size_t
lastsize
,
double
fct
)
{
if
(
in
.
dtype
().
is
(
c128
))
return
irfftn_internal
<
double
>
(
in
,
axes_
,
lastsize
,
fct
);
else
if
(
in
.
dtype
().
is
(
c64
))
return
irfftn_internal
<
float
>
(
in
,
axes_
,
lastsize
,
fct
);
throw
runtime_error
(
"unsupported data type"
);
}
template
<
typename
T
>
py
::
array
hartley_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
...
...
@@ -2493,10 +2485,8 @@ template<typename T> py::array hartley_internal(const py::array &in,
auto
axes
=
makeaxes
(
in
,
axes_
);
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
py
::
array_t
<
T
>
(
dims
);
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
make_strides
(
in
,
res
,
s_i
,
s_o
);
pocketfft_general_hartley
<
T
>
(
dims
,
s_i
,
s_o
,
axes
,
(
const
T
*
)
in
.
data
(),
auto
s_i
(
copy_strides
(
in
)),
s_o
(
copy_strides
(
res
));
pocketfft_general_hartley
<
T
>
(
dims
,
s_i
,
s_o
,
axes
,
(
const
T
*
)
in
.
data
(),
(
T
*
)
res
.
mutable_data
(),
1.
);
return
res
;
}
...
...
@@ -2516,8 +2506,7 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto
dims_out
(
copy_shape
(
in
));
py
::
array
out
=
py
::
array_t
<
T
>
(
dims_out
);
auto
dims_tmp
(
copy_shape
(
tmp
));
vector
<
int64_t
>
stride_tmp
(
ndim
),
stride_out
(
ndim
);
make_strides
(
tmp
,
out
,
stride_tmp
,
stride_out
);
auto
stride_tmp
(
copy_strides
(
tmp
)),
stride_out
(
copy_strides
(
out
));
auto
axes
=
makeaxes
(
in
,
axes_
);
size_t
axis
=
axes
.
back
();
multiarr
a_tmp
(
dims_tmp
,
stride_tmp
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment