Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
9dae4cdd
Commit
9dae4cdd
authored
Apr 17, 2019
by
Martin Reinecke
Browse files
more
parent
92cfcfdd
Changes
1
Hide whitespace changes
Inline
Side-by-side
pocketfft.cc
View file @
9dae4cdd
...
...
@@ -6,6 +6,7 @@
#include
<vector>
#include
<pybind11/pybind11.h>
#include
<pybind11/numpy.h>
#include
<pybind11/stl.h>
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
...
...
@@ -2064,6 +2065,76 @@ template<typename T> void pocketfft_general_r(int ndim, const size_t *shape,
it_out
.
advance
();
}
}
template
<
typename
T
>
void
pocketfft_general_r2c
(
int
ndim
,
const
size_t
*
shape
,
const
int64_t
*
stride_in
,
const
int64_t
*
stride_out
,
size_t
axis
,
const
T
*
data_in
,
cmplx
<
T
>
*
data_out
,
T
fct
)
{
// allocate temporary 1D array storage
arr
<
T
>
tdata
(
shape
[
axis
]);
multiarr
a_in
(
ndim
,
shape
,
stride_in
),
a_out
(
ndim
,
shape
,
stride_out
);
pocketfft_r
<
T
>
plan
(
shape
[
axis
]);
multi_iter
it_in
(
a_in
,
axis
),
it_out
(
a_out
,
axis
);
size_t
len
=
shape
[
axis
],
s_i
=
it_in
.
stride
(),
s_o
=
it_out
.
stride
();
while
(
!
it_in
.
done
())
{
const
T
*
d_i
=
data_in
+
it_in
.
offset
();
cmplx
<
T
>
*
d_o
=
data_out
+
it_out
.
offset
();
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
tdata
[
i
]
=
d_i
[
i
*
s_i
];
plan
.
forward
(
tdata
.
data
(),
fct
);
d_o
[
0
].
r
=
tdata
[
0
];
d_o
[
0
].
i
=
0.
;
size_t
i
;
for
(
i
=
1
;
i
<
len
-
1
;
i
+=
2
)
{
size_t
io
=
(
i
+
1
)
/
2
;
d_o
[
io
*
s_o
].
r
=
tdata
[
i
];
d_o
[
io
*
s_o
].
i
=
tdata
[
i
+
1
];
}
if
(
i
<
len
)
{
size_t
io
=
(
i
+
1
)
/
2
;
d_o
[
io
*
s_o
].
r
=
tdata
[
i
];
d_o
[
io
*
s_o
].
i
=
0.
;
}
it_in
.
advance
();
it_out
.
advance
();
}
}
template
<
typename
T
>
void
pocketfft_general_c2r
(
int
ndim
,
const
size_t
*
shape_out
,
const
int64_t
*
stride_in
,
const
int64_t
*
stride_out
,
size_t
axis
,
const
cmplx
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
)
{
// allocate temporary 1D array storage
arr
<
T
>
tdata
(
shape_out
[
axis
]);
multiarr
a_in
(
ndim
,
shape_out
,
stride_in
),
a_out
(
ndim
,
shape_out
,
stride_out
);
pocketfft_r
<
T
>
plan
(
shape_out
[
axis
]);
multi_iter
it_in
(
a_in
,
axis
),
it_out
(
a_out
,
axis
);
size_t
len
=
shape_out
[
axis
],
s_i
=
it_in
.
stride
(),
s_o
=
it_out
.
stride
();
while
(
!
it_in
.
done
())
{
const
cmplx
<
T
>
*
d_i
=
data_in
+
it_in
.
offset
();
T
*
d_o
=
data_out
+
it_out
.
offset
();
tdata
[
0
]
=
d_i
[
0
].
r
;
size_t
i
;
for
(
i
=
1
;
i
<
len
-
1
;
i
+=
2
)
{
size_t
ii
=
(
i
+
1
)
/
2
;
tdata
[
i
]
=
d_i
[
ii
*
s_i
].
r
;
tdata
[
i
+
1
]
=
d_i
[
ii
*
s_i
].
i
;
}
if
(
i
<
len
)
{
size_t
ii
=
(
i
+
1
)
/
2
;
tdata
[
i
]
=
d_i
[
ii
*
s_i
].
r
;
}
plan
.
backward
(
tdata
.
data
(),
fct
);
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
d_o
[
i
*
s_o
]
=
tdata
[
i
];
it_in
.
advance
();
it_out
.
advance
();
}
}
namespace
py
=
pybind11
;
...
...
@@ -2072,10 +2143,24 @@ auto c128 = py::dtype("complex128");
auto
f32
=
py
::
dtype
(
"float32"
);
auto
f64
=
py
::
dtype
(
"float64"
);
py
::
array
execute
(
const
py
::
array
&
in
,
bool
fwd
)
void
check_args
(
const
py
::
array
&
in
,
vector
<
size_t
>
&
axes
)
{
if
(
axes
.
size
()
==
0
)
axes
.
resize
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
axes
[
i
]
=
i
;
if
(
axes
.
size
()
>
size_t
(
in
.
ndim
()))
throw
runtime_error
(
"bad axes argument"
);
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
if
(
axes
[
i
]
>=
size_t
(
in
.
ndim
()))
throw
runtime_error
(
"invalid axis number"
);
}
py
::
array
execute
(
const
py
::
array
&
in
,
vector
<
size_t
>
axes
,
bool
fwd
)
{
check_args
(
in
,
axes
);
py
::
array
res
;
vector
<
size_t
>
dims
(
in
.
ndim
())
,
axes
(
in
.
ndim
())
;
vector
<
size_t
>
dims
(
in
.
ndim
());
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
dims
[
i
]
=
in
.
shape
(
i
);
...
...
@@ -2092,36 +2177,39 @@ py::array execute(const py::array &in, bool fwd)
s_o
[
i
]
=
res
.
strides
(
i
)
/
res
.
itemsize
();
if
(
s_o
[
i
]
*
res
.
itemsize
()
!=
res
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
axes
[
i
]
=
i
;
}
if
(
in
.
dtype
().
is
(
c128
))
pocketfft_general_c
<
double
>
(
in
.
ndim
(),
dims
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
in
.
ndim
(),
s_i
.
data
(),
s_o
.
data
(),
axes
.
size
(),
axes
.
data
(),
fwd
,
(
const
cmplx
<
double
>
*
)
in
.
data
(),
(
cmplx
<
double
>
*
)
res
.
mutable_data
(),
1.
);
else
pocketfft_general_c
<
float
>
(
in
.
ndim
(),
dims
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
in
.
ndim
(),
s_i
.
data
(),
s_o
.
data
(),
axes
.
size
(),
axes
.
data
(),
fwd
,
(
const
cmplx
<
float
>
*
)
in
.
data
(),
(
cmplx
<
float
>
*
)
res
.
mutable_data
(),
1.
);
return
res
;
}
py
::
array
fftn
(
const
py
::
array
&
in
)
{
return
execute
(
in
,
true
);
}
py
::
array
ifftn
(
const
py
::
array
&
in
)
{
return
execute
(
in
,
false
);
}
py
::
array
fftn
(
const
py
::
array
&
a
,
const
vector
<
size_t
>
&
axes
)
{
return
execute
(
a
,
axes
,
true
);
}
py
::
array
ifftn
(
const
py
::
array
&
a
,
const
vector
<
size_t
>
&
axes
)
{
return
execute
(
a
,
axes
,
false
);
}
py
::
array
execute_r
(
const
py
::
array
&
in
,
int
axis
,
bool
fwd
)
py
::
array
rfftn
(
const
py
::
array
&
in
,
int
axis
)
{
py
::
array
res
;
vector
<
size_t
>
dims
(
in
.
ndim
());
vector
<
size_t
>
dims
_in
(
in
.
ndim
()),
dims_out
(
in
.
ndim
());
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
dims
[
i
]
=
in
.
shape
(
i
);
{
dims_in
[
i
]
=
in
.
shape
(
i
);
dims_out
[
i
]
=
in
.
shape
(
i
);
}
dims_out
[
axis
]
=
(
dims_out
[
axis
]
>>
1
)
+
1
;
if
(
in
.
dtype
().
is
(
f64
))
res
=
py
::
array_t
<
double
>
(
dims
);
res
=
py
::
array_t
<
complex
<
double
>
>
(
dims
_out
);
else
if
(
in
.
dtype
().
is
(
f32
))
res
=
py
::
array_t
<
float
>
(
dims
);
res
=
py
::
array_t
<
complex
<
float
>
>
(
dims
_out
);
else
throw
runtime_error
(
"unsupported data type"
);
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
{
...
...
@@ -2133,27 +2221,71 @@ py::array execute_r(const py::array &in, int axis, bool fwd)
throw
runtime_error
(
"weird strides"
);
}
if
(
in
.
dtype
().
is
(
f64
))
pocketfft_general_r
<
double
>
(
in
.
ndim
(),
dims
.
data
(),
pocketfft_general_r2c
<
double
>
(
in
.
ndim
(),
dims_in
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
axis
,
(
const
double
*
)
in
.
data
(),
(
cmplx
<
double
>
*
)
res
.
mutable_data
(),
1.
);
else
pocketfft_general_r2c
<
float
>
(
in
.
ndim
(),
dims_in
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
axis
,
fwd
,
(
const
double
*
)
in
.
data
(),
axis
,
(
const
float
*
)
in
.
data
(),
(
cmplx
<
float
>
*
)
res
.
mutable_data
(),
1.
);
return
res
;
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
int
axis
,
int
osize
)
{
py
::
array
res
;
if
(
osize
/
2
+
1
!=
in
.
shape
(
axis
))
throw
runtime_error
(
"bad output size"
);
vector
<
size_t
>
dims_in
(
in
.
ndim
()),
dims_out
(
in
.
ndim
());
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
{
dims_in
[
i
]
=
in
.
shape
(
i
);
dims_out
[
i
]
=
in
.
shape
(
i
);
}
dims_out
[
axis
]
=
osize
;
if
(
in
.
dtype
().
is
(
c128
))
res
=
py
::
array_t
<
double
>
(
dims_out
);
else
if
(
in
.
dtype
().
is
(
c64
))
res
=
py
::
array_t
<
float
>
(
dims_out
);
else
throw
runtime_error
(
"unsupported data type"
);
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
{
s_i
[
i
]
=
in
.
strides
(
i
)
/
in
.
itemsize
();
if
(
s_i
[
i
]
*
in
.
itemsize
()
!=
in
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
s_o
[
i
]
=
res
.
strides
(
i
)
/
res
.
itemsize
();
if
(
s_o
[
i
]
*
res
.
itemsize
()
!=
res
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
}
if
(
in
.
dtype
().
is
(
c128
))
pocketfft_general_c2r
<
double
>
(
in
.
ndim
(),
dims_out
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
axis
,
(
const
cmplx
<
double
>
*
)
in
.
data
(),
(
double
*
)
res
.
mutable_data
(),
1.
);
else
pocketfft_general_r
<
float
>
(
in
.
ndim
(),
dims
.
data
(),
pocketfft_general_
c2
r
<
float
>
(
in
.
ndim
(),
dims
_out
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
axis
,
fwd
,
(
const
float
*
)
in
.
data
(),
axis
,
(
const
cmplx
<
float
>
*
)
in
.
data
(),
(
float
*
)
res
.
mutable_data
(),
1.
);
return
res
;
}
py
::
array
rfftn
(
const
py
::
array
&
in
,
int
axis
)
{
return
execute_r
(
in
,
axis
,
true
);
}
py
::
array
irfftn
(
const
py
::
array
&
in
,
int
axis
)
{
return
execute_r
(
in
,
axis
,
false
);
}
void
xtest
(
const
py
::
array
&
a
,
const
vector
<
size_t
>
&
s
,
const
vector
<
size_t
>
&
axes
,
const
string
&
norm
)
{
// for (size_t i=0; i<data.size(); ++i)
// cout << data[i] << endl;
}
}
// unnamed namespace
PYBIND11_MODULE
(
pypocketfft
,
m
)
{
m
.
def
(
"fftn"
,
&
fftn
);
m
.
def
(
"ifftn"
,
&
ifftn
);
using
namespace
pybind11
::
literals
;
m
.
def
(
"fftn"
,
&
fftn
,
"a"
_a
,
"axes"
_a
=
vector
<
size_t
>
());
m
.
def
(
"ifftn"
,
&
ifftn
,
"a"
_a
,
"axes"
_a
=
vector
<
size_t
>
());
m
.
def
(
"rfftn"
,
&
rfftn
);
m
.
def
(
"irfftn"
,
&
irfftn
);
m
.
def
(
"xtest"
,
&
xtest
,
"a"
_a
,
"s"
_a
=
vector
<
size_t
>
(),
"axes"
_a
=
vector
<
size_t
>
(),
"norm"
_a
=
""
);
}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment