Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
02db4525
Commit
02db4525
authored
Apr 23, 2019
by
Martin Reinecke
Browse files
structure
parent
97730839
Changes
1
Hide whitespace changes
Inline
Side-by-side
pocketfft.cc
View file @
02db4525
...
...
@@ -72,6 +72,10 @@ template<typename T> struct arr
size_t
size
()
const
{
return
sz
;
}
};
//
// twiddle factor section
//
class
sincos_2pibyn
{
private:
...
...
@@ -357,6 +361,10 @@ template<typename T> void ROTM90(cmplx<T> &a)
constexpr
size_t
NFCT
=
25
;
//
// complex FFTPACK transforms
//
template
<
typename
T0
>
class
cfftp
{
private:
...
...
@@ -940,6 +948,9 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
}
};
//
// real-valued FFTPACK transforms
//
template
<
typename
T0
>
class
rfftp
{
...
...
@@ -1740,6 +1751,10 @@ NOINLINE rfftp(size_t length_)
};
//
// complex Bluestein transforms
//
template
<
typename
T0
>
class
fftblue
{
private:
...
...
@@ -1854,6 +1869,10 @@ template<typename T0> class fftblue
}
};
//
// flexible (FFTPACK/Bluestein) complex 1D transform
//
template
<
typename
T0
>
class
pocketfft_c
{
private:
...
...
@@ -1894,6 +1913,11 @@ template<typename T0> class pocketfft_c
size_t
length
()
const
{
return
len
;
}
};
//
// flexible (FFTPACK/Bluestein) real-valued 1D transform
//
template
<
typename
T0
>
class
pocketfft_r
{
private:
...
...
@@ -1935,6 +1959,10 @@ template<typename T0> class pocketfft_r
size_t
length
()
const
{
return
len
;
}
};
//
// multi-D infrastructure
//
struct
diminfo
{
size_t
n
;
int64_t
s
;
};
class
multiarr
...
...
@@ -2345,6 +2373,10 @@ template<typename T> void pocketfft_general_c2r(const vector<size_t> &shape_out,
}
}
//
// Python interface
//
namespace
py
=
pybind11
;
auto
c64
=
py
::
dtype
(
"complex64"
);
...
...
@@ -2359,6 +2391,7 @@ vector<size_t> copy_shape(const py::array &arr)
res
[
i
]
=
arr
.
shape
(
i
);
return
res
;
}
vector
<
int64_t
>
copy_strides
(
const
py
::
array
&
arr
)
{
vector
<
int64_t
>
res
(
arr
.
ndim
());
...
...
@@ -2389,32 +2422,29 @@ vector<size_t> makeaxes(const py::array &in, py::object axes)
return
tmp
;
}
template
<
typename
T
>
py
::
array
execute
(
const
py
::
array
&
in
,
const
vector
<
size_t
>
&
axes
,
double
fct
,
bool
fwd
)
template
<
typename
T
>
py
::
array
xfftn_internal
(
const
py
::
array
&
in
,
const
vector
<
size_t
>
&
axes
,
double
fct
,
bool
fwd
)
{
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
py
::
array_t
<
complex
<
T
>>
(
dims
);
auto
s_i
(
copy_strides
(
in
)),
s_o
(
copy_strides
(
res
));
pocketfft_general_c
<
T
>
(
dims
,
s_i
,
s_o
,
axes
,
fwd
,
(
const
cmplx
<
T
>
*
)
in
.
data
(),
pocketfft_general_c
<
T
>
(
dims
,
s_i
,
s_o
,
axes
,
fwd
,
(
const
cmplx
<
T
>
*
)
in
.
data
(),
(
cmplx
<
T
>
*
)
res
.
mutable_data
(),
fct
);
return
res
;
}
py
::
array
fftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
)
py
::
array
xfftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
,
bool
fwd
)
{
if
(
a
.
dtype
().
is
(
c128
))
return
execute
<
double
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
true
);
return
xfftn_internal
<
double
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
fwd
);
else
if
(
a
.
dtype
().
is
(
c64
))
return
execute
<
float
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
true
);
return
xfftn_internal
<
float
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
fwd
);
throw
runtime_error
(
"unsupported data type"
);
}
py
::
array
fftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
)
{
return
xfftn
(
a
,
axes
,
fct
,
true
);
}
py
::
array
ifftn
(
const
py
::
array
&
a
,
py
::
object
axes
,
double
fct
)
{
if
(
a
.
dtype
().
is
(
c128
))
return
execute
<
double
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
false
);
else
if
(
a
.
dtype
().
is
(
c64
))
return
execute
<
float
>
(
a
,
makeaxes
(
a
,
axes
),
fct
,
false
);
throw
runtime_error
(
"unsupported data type"
);
}
{
return
xfftn
(
a
,
axes
,
fct
,
false
);
}
template
<
typename
T
>
py
::
array
rfftn_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
T
fct
)
...
...
@@ -2452,7 +2482,7 @@ template<typename T> py::array irfftn_internal(const py::array &in,
vector
<
size_t
>
axes2
(
axes
.
size
()
-
1
);
for
(
size_t
i
=
0
;
i
<
axes2
.
size
();
++
i
)
axes2
[
i
]
=
axes
[
i
];
inter
=
execute
<
T
>
(
in
,
axes2
,
1.
,
false
);
inter
=
xfftn_internal
<
T
>
(
in
,
axes2
,
1.
,
false
);
}
else
inter
=
in
;
...
...
@@ -2482,7 +2512,7 @@ py::array irfftn(const py::array &in, py::object axes_, size_t lastsize,
template
<
typename
T
>
py
::
array
hartley_internal
(
const
py
::
array
&
in
,
py
::
object
axes_
,
double
fct
)
{
auto
axes
=
makeaxes
(
in
,
axes_
);
auto
axes
(
makeaxes
(
in
,
axes_
)
)
;
auto
dims
(
copy_shape
(
in
));
py
::
array
res
=
py
::
array_t
<
T
>
(
dims
);
auto
s_i
(
copy_strides
(
in
)),
s_o
(
copy_strides
(
res
));
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment