Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
320b7b8d
Commit
320b7b8d
authored
Apr 16, 2019
by
Martin Reinecke
Browse files
cleanups
parent
06d75f2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
pocketfft.cc
View file @
320b7b8d
...
...
@@ -4,6 +4,8 @@
#include
<iostream>
#include
<memory>
#include
<vector>
#include
<pybind11/pybind11.h>
#include
<pybind11/numpy.h>
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
...
...
@@ -1047,14 +1049,14 @@ template<typename T0> class pocketfft_c
};
struct
diminfo
{
size_t
n
,
s
;
};
{
size_t
n
;
int64_t
s
;
};
class
multiarr
{
private:
vector
<
diminfo
>
dim
;
public:
multiarr
(
size_t
ndim_
,
const
size_t
*
n
,
const
size
_t
*
s
)
multiarr
(
size_t
ndim_
,
const
size_t
*
n
,
const
int64
_t
*
s
)
{
dim
.
reserve
(
ndim_
);
for
(
size_t
i
=
0
;
i
<
ndim_
;
++
i
)
...
...
@@ -1062,7 +1064,7 @@ class multiarr
}
size_t
ndim
()
const
{
return
dim
.
size
();
}
size_t
size
(
size_t
i
)
const
{
return
dim
[
i
].
n
;
}
size
_t
stride
(
size_t
i
)
const
{
return
dim
[
i
].
s
;
}
int64
_t
stride
(
size_t
i
)
const
{
return
dim
[
i
].
s
;
}
};
class
multi_iter
...
...
@@ -1070,7 +1072,8 @@ class multi_iter
private:
vector
<
diminfo
>
dim
;
vector
<
size_t
>
pos
;
size_t
ofs_
,
len
,
str
;
size_t
ofs_
,
len
;
int64_t
str
;
bool
done_
;
public:
...
...
@@ -1102,11 +1105,11 @@ class multi_iter
bool
done
()
const
{
return
done_
;
}
size_t
offset
()
const
{
return
ofs_
;
}
size_t
length
()
const
{
return
len
;
}
size
_t
stride
()
const
{
return
str
;
}
int64
_t
stride
()
const
{
return
str
;
}
};
template
<
typename
T
>
void
pocketfft_general_c
(
int
ndim
,
const
size_t
*
shape
,
const
size
_t
*
stride_in
,
const
size
_t
*
stride_out
,
int
nax
,
const
int64
_t
*
stride_in
,
const
int64
_t
*
stride_out
,
int
nax
,
const
size_t
*
axes
,
bool
forward
,
const
cmplx
<
T
>
*
data_in
,
cmplx
<
T
>
*
data_out
,
T
fct
)
{
...
...
@@ -1115,7 +1118,7 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
for
(
int
iax
=
0
;
iax
<
nax
;
++
iax
)
{
bool
inplace
=
false
;
size
_t
stride
=
(
iax
==
0
)
?
stride_in
[
axes
[
iax
]]
:
stride_out
[
axes
[
iax
]];
int64
_t
stride
=
(
iax
==
0
)
?
stride_in
[
axes
[
iax
]]
:
stride_out
[
axes
[
iax
]];
if
(
stride
==
1
)
if
((
iax
>
0
)
||
(
data_in
==
data_out
))
inplace
=
true
;
...
...
@@ -1158,37 +1161,38 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
a_in
=
a_out
;
data_in
=
data_out
;
// factor has been applied, use 1 for remaining axes
fct
=
1.
;
fct
=
T
(
1
)
;
}
}
}
// unnamed namespace
#if 1
#include
<pybind11/pybind11.h>
#include
<pybind11/numpy.h>
namespace
py
=
pybind11
;
namespace
{
auto
c64
=
py
::
dtype
(
"complex64"
);
auto
c128
=
py
::
dtype
(
"complex128"
);
py
::
array
execute
(
const
py
::
array
&
in
,
bool
fwd
)
{
py
::
array
res
;
vector
<
size_t
>
dims
(
in
.
ndim
()),
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
()),
axes
(
in
.
ndim
());
vector
<
size_t
>
dims
(
in
.
ndim
()),
axes
(
in
.
ndim
());
vector
<
int64_t
>
s_i
(
in
.
ndim
()),
s_o
(
in
.
ndim
());
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
dims
[
i
]
=
in
.
shape
(
i
);
if
(
in
.
dtype
().
is
(
py
::
dtype
(
"complex
128
"
)
))
if
(
in
.
dtype
().
is
(
c
128
))
res
=
py
::
array_t
<
complex
<
double
>>
(
dims
);
else
if
(
in
.
dtype
().
is
(
py
::
dtype
(
"complex64"
)
))
else
if
(
in
.
dtype
().
is
(
c64
))
res
=
py
::
array_t
<
complex
<
float
>>
(
dims
);
else
throw
runtime_error
(
"unsupported data type"
);
for
(
int
i
=
0
;
i
<
in
.
ndim
();
++
i
)
{
s_i
[
i
]
=
in
.
strides
(
i
)
/
in
.
itemsize
();
if
(
s_i
[
i
]
*
in
.
itemsize
()
!=
in
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
s_o
[
i
]
=
res
.
strides
(
i
)
/
res
.
itemsize
();
if
(
s_o
[
i
]
*
res
.
itemsize
()
!=
res
.
strides
(
i
))
throw
runtime_error
(
"weird strides"
);
axes
[
i
]
=
i
;
}
if
(
in
.
dtype
().
is
(
py
::
dtype
(
"complex
128
"
)
))
if
(
in
.
dtype
().
is
(
c
128
))
pocketfft_general_c
<
double
>
(
in
.
ndim
(),
dims
.
data
(),
s_i
.
data
(),
s_o
.
data
(),
in
.
ndim
(),
axes
.
data
(),
fwd
,
(
const
cmplx
<
double
>
*
)
in
.
data
(),
...
...
@@ -1211,19 +1215,3 @@ PYBIND11_MODULE(pypocketfft, m)
m
.
def
(
"fftn"
,
&
fftn
);
m
.
def
(
"ifftn"
,
&
ifftn
);
}
#else
int
main
()
{
int
sz
=
100
;
vector
<
cmplx
<
double
>>
data
(
sz
),
data2
(
sz
);
size_t
shape
[]
=
{
sz
};
size_t
stride
[]
=
{
1
};
size_t
axes
[]
=
{
0
};
pocketfft_general_c
(
1
,
shape
,
stride
,
stride
,
1
,
axes
,
true
,
data
.
data
(),
data2
.
data
(),
1.
);
}
#endif
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment