ducc, commit e2b8a7ca
authored Jan 18, 2020 by Martin Reinecke

start using fmav in FFT

parent 00c8d288
Changes: 8
.gitignore
@@ -87,3 +87,4 @@ libtool
 ar-lib
 *-uninstalled.sh
 missing
+test-driver

libsharp2/sharp_geomhelpers.cc
@@ -243,7 +243,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer1_geom_info (size_t nrings, size_t p
     weight[2*k]=2./(1.-4.*k*k)*sin((k*pi)/nrings);
     }
   if ((nrings&1)==0) weight[nrings-1]=0.;
-  mr::r2r_fftpack({size_t(nrings)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
+  mr::r2r_fftpack({size_t(nrings)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
   for (size_t m=0; m<(nrings+1)/2; ++m)
     {
@@ -275,7 +275,7 @@ unique_ptr<sharp_geom_info> sharp_make_cc_geom_info (size_t nrings, size_t pprin
   for (size_t k=1; k<=(n/2-1); ++k)
     weight[2*k-1]=2./(1.-4.*k*k) + dw;
   weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1. -dw*((2-(n&1))*n-1);
-  mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
+  mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
   weight[n]=weight[0];
   for (size_t m=0; m<(nrings+1)/2; ++m)
@@ -308,7 +308,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer2_geom_info (size_t nrings, size_t p
   for (size_t k=1; k<=(n/2-1); ++k)
     weight[2*k-1]=2./(1.-4.*k*k);
   weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1.;
-  mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
+  mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
   for (size_t m=0; m<nrings; ++m)
     weight[m]=weight[m+1];
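The only change in these three hunks is the stride arguments: {sizeof(double)} becomes {1}, because the FFT interface now counts strides in array elements rather than in bytes. A minimal standalone sketch of that conversion for a row-major array of double (illustrative names, not part of the commit):

#include <cstddef>
#include <vector>

int main()
  {
  const std::ptrdiff_t ncols = 5;
  // old convention: byte strides for a row-major nrows x ncols array of double
  std::vector<std::ptrdiff_t> stride_bytes{ncols*std::ptrdiff_t(sizeof(double)),
                                           std::ptrdiff_t(sizeof(double))};
  // new convention: the same layout expressed in elements
  std::vector<std::ptrdiff_t> stride_elems{ncols, 1};
  // both describe identical memory layouts
  for (std::size_t i=0; i<2; ++i)
    if (stride_bytes[i] != stride_elems[i]*std::ptrdiff_t(sizeof(double)))
      return 1;
  return 0;
  }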

mr_util/fft.h
@@ -60,6 +60,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "mr_util/unity_roots.h"
 #include "mr_util/useful_macros.h"
 #include "mr_util/simd.h"
+#include "mr_util/mav.h"
 #ifndef MRUTIL_NO_THREADING
 #include <mutex>
 #endif
@@ -76,9 +77,6 @@ template <typename T> T cos(T) = delete;
 template<typename T> T sin(T) = delete;
 template<typename T> T sqrt(T) = delete;
-using shape_t = std::vector<size_t>;
-using stride_t = std::vector<ptrdiff_t>;
 constexpr bool FORWARD = true, BACKWARD = false;
@@ -2293,51 +2291,11 @@ template<typename T> std::shared_ptr<T> get_plan(size_t length)
 #endif
   }

-class arr_info
-  {
-  protected:
-    shape_t shp;
-    stride_t str;
-
-  public:
-    arr_info(const shape_t &shape_, const stride_t &stride_)
-      : shp(shape_), str(stride_) {}
-    size_t ndim() const { return shp.size(); }
-    size_t size() const { return util::prod(shp); }
-    const shape_t &shape() const { return shp; }
-    size_t shape(size_t i) const { return shp[i]; }
-    const stride_t &stride() const { return str; }
-    const ptrdiff_t &stride(size_t i) const { return str[i]; }
-  };
-
-template<typename T> class cndarr: public arr_info
-  {
-  protected:
-    const char *d;
-
-  public:
-    cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
-      : arr_info(shape_, stride_),
-        d(reinterpret_cast<const char *>(data_)) {}
-    const T &operator[](ptrdiff_t ofs) const
-      { return *reinterpret_cast<const T *>(d+ofs); }
-  };
-
-template<typename T> class ndarr: public cndarr<T>
-  {
-  public:
-    ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
-      : cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)
-      {}
-    T &operator[](ptrdiff_t ofs)
-      { return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d+ofs)); }
-  };
-
 template<size_t N> class multi_iter
   {
   private:
     shape_t pos;
-    const arr_info &iarr, &oarr;
+    fmav_info iarr, oarr;
     ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
     size_t idim, rem;
@@ -2358,7 +2316,7 @@ template<size_t N> class multi_iter
}
public:
multi_iter
(
const
arr
_info
&
iarr_
,
const
arr
_info
&
oarr_
,
size_t
idim_
,
multi_iter
(
const
fmav
_info
&
iarr_
,
const
fmav
_info
&
oarr_
,
size_t
idim_
,
size_t
nshares
,
size_t
myshare
)
:
pos
(
iarr_
.
ndim
(),
0
),
iarr
(
iarr_
),
oarr
(
oarr_
),
p_ii
(
0
),
str_i
(
iarr
.
stride
(
idim_
)),
p_oi
(
0
),
str_o
(
oarr
.
stride
(
idim_
)),
...
...
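The deleted arr_info/cndarr/ndarr helpers are superseded by fmav_info/cfmav/fmav from mr_util/mav.h (added later in this commit), and the iterators now keep an fmav_info by value rather than a reference to an arr_info. A rough standalone sketch of the indexing difference, using hypothetical stand-in structs instead of the library classes: the old cndarr addressed raw memory by byte offset, while the new cfmav indexes a typed pointer by element.

#include <cstddef>
#include <vector>

// rough analogue of the removed cndarr<double>: raw memory, byte offsets
struct byte_view
  {
  const char *d;
  const double &operator[](std::ptrdiff_t byte_ofs) const
    { return *reinterpret_cast<const double *>(d+byte_ofs); }
  };

// rough analogue of the new cfmav<double>: typed pointer, element offsets
struct elem_view
  {
  const double *d;
  const double &operator[](std::ptrdiff_t elem_ofs) const
    { return d[elem_ofs]; }
  };

int main()
  {
  std::vector<double> buf{0., 1., 2., 3.};
  byte_view bv{reinterpret_cast<const char *>(buf.data())};
  elem_view ev{buf.data()};
  // element 2 of the buffer, addressed in each convention
  return (bv[2*std::ptrdiff_t(sizeof(double))] == ev[2]) ? 0 : 1;
  }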
@@ -2412,12 +2370,12 @@ class simple_iter
   {
   private:
     shape_t pos;
-    const arr_info &arr;
+    fmav_info arr;
     ptrdiff_t p;
     size_t rem;

   public:
-    simple_iter(const arr_info &arr_)
+    simple_iter(const fmav_info &arr_)
      : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
    void advance()
      {
@@ -2440,7 +2398,7 @@ class rev_iter
   {
   private:
     shape_t pos;
-    const arr_info &arr;
+    fmav_info arr;
     std::vector<char> rev_axis;
     std::vector<char> rev_jump;
     size_t last_axis, last_size;
@@ -2449,7 +2407,7 @@ class rev_iter
     size_t rem;

   public:
-    rev_iter(const arr_info &arr_, const shape_t &axes)
+    rev_iter(const fmav_info &arr_, const shape_t &axes)
      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
        rev_jump(arr_.ndim(), 1), p(0), rp(0)
      {
@@ -2517,15 +2475,15 @@ template<> struct VTYPE<long double>
   };
 #endif

-template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
-  size_t axsize, size_t elemsize)
+template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
+  size_t axsize)
   {
   auto othersize = util::prod(shape)/axsize;
-  auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
-  return aligned_array<char>(tmpsize*elemsize);
+  auto tmpsize = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
+  return aligned_array<T>(tmpsize);
   }
-template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
-  const shape_t &axes, size_t elemsize)
+template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
+  const shape_t &axes)
   {
   size_t fullsize=util::prod(shape);
   size_t tmpsize=0;
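alloc_tmp is now parameterized on both the storage type T and the arithmetic type T0 and returns a typed aligned_array<T> whose length is counted in elements, instead of an aligned_array<char> sized in bytes via the removed elemsize argument. A minimal sketch of the sizing rule only, with std::vector standing in for the library's aligned_array and vlen_T0 standing in for VLEN<T0>::val (both substitutions are mine, not the commit's):

#include <cstddef>
#include <vector>

template<typename T> std::vector<T> alloc_tmp_sketch(std::size_t axsize,
  std::size_t othersize, std::size_t vlen_T0)
  {
  // same sizing rule as in the hunk above: pad by the SIMD width only if the
  // other axes provide enough independent transforms to fill a vector register
  std::size_t tmpsize = axsize*((othersize>=vlen_T0) ? vlen_T0 : 1);
  return std::vector<T>(tmpsize);
  }

int main()
  {
  auto tmp = alloc_tmp_sketch<double>(64, 1000, 4);
  return (tmp.size()==64*4) ? 0 : 1;
  }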
@@ -2533,14 +2491,14 @@ template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
     {
     auto axsize = shape[axes[i]];
     auto othersize = fullsize/axsize;
-    auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
+    auto sz = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
     if (sz>tmpsize) tmpsize=sz;
     }
-  return aligned_array<char>(tmpsize*elemsize);
+  return aligned_array<T>(tmpsize);
   }

 template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
-  const cndarr<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst)
+  const cfmav<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst)
   {
   for (size_t i=0; i<it.length_in(); ++i)
     for (size_t j=0; j<vlen; ++j)
@@ -2551,7 +2509,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
   }

 template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
-  const cndarr<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst)
+  const cfmav<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst)
   {
   for (size_t i=0; i<it.length_in(); ++i)
     for (size_t j=0; j<vlen; ++j)
@@ -2559,7 +2517,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
   }

 template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
-  const cndarr<T> &src, T *MRUTIL_RESTRICT dst)
+  const cfmav<T> &src, T *MRUTIL_RESTRICT dst)
   {
   if (dst == &src[it.iofs(0)]) return;  // in-place
   for (size_t i=0; i<it.length_in(); ++i)
@@ -2567,7 +2525,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
   }

 template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
-  const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, ndarr<Cmplx<T>> &dst)
+  const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, const fmav<Cmplx<T>> &dst)
   {
   for (size_t i=0; i<it.length_out(); ++i)
     for (size_t j=0; j<vlen; ++j)
@@ -2575,7 +2533,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
   }

 template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
-  const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst)
+  const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
   {
   for (size_t i=0; i<it.length_out(); ++i)
     for (size_t j=0; j<vlen; ++j)
@@ -2583,7 +2541,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
   }

 template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
-  const T *MRUTIL_RESTRICT src, ndarr<T> &dst)
+  const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
   {
   if (src == &dst[it.oofs(0)]) return;  // in-place
   for (size_t i=0; i<it.length_out(); ++i)
@@ -2596,7 +2554,7 @@ template <typename T> struct add_vec<Cmplx<T>>
 template <typename T> using add_vec_t = typename add_vec<T>::type;

 template<typename Tplan, typename T, typename T0, typename Exec>
-MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
+MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
   const shape_t &axes, T0 fct, size_t nthreads,
   const Exec &exec, const bool allow_inplace=true)
   {
@@ -2612,8 +2570,8 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
       util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
       [&](Scheduler &sched) {
         constexpr auto vlen = VLEN<T0>::val;
-        auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
-        const auto &tin(iax==0 ? in : out);
+        auto storage = alloc_tmp<T,T0>(in.shape(), len);
+        const auto &tin(iax==0 ? in : cfmav<T>(out));
         multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
         if (vlen>1)
@@ -2627,7 +2585,7 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
         while (it.remaining()>0)
           {
           it.advance(1);
-          auto buf = allow_inplace && it.stride_out() == sizeof(T) ?
+          auto buf = allow_inplace && it.stride_out() == 1 ?
             &out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
           exec(it, tin, out, buf, *plan, fct);
           }
@@ -2641,8 +2599,8 @@ struct ExecC2C
   bool forward;

   template <typename T0, typename T, size_t vlen> void operator() (
-    const multi_iter<vlen> &it, const cndarr<Cmplx<T0>> &in,
-    ndarr<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
+    const multi_iter<vlen> &it, const cfmav<Cmplx<T0>> &in,
+    const fmav<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
     {
     copy_input(it, in, buf);
     plan.exec(buf, fct, forward);
@@ -2651,7 +2609,7 @@ struct ExecC2C
   };

 template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
-  const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst)
+  const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
   {
   for (size_t j=0; j<vlen; ++j)
     dst[it.oofs(j,0)] = src[0][j];
@@ -2668,7 +2626,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
   }

 template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
-  const T *MRUTIL_RESTRICT src, ndarr<T> &dst)
+  const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
   {
   dst[it.oofs(0)] = src[0];
   size_t i=1, i1=1, i2=it.length_out()-1;
@@ -2684,7 +2642,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
 struct ExecHartley
   {
   template <typename T0, typename T, size_t vlen> void operator () (
-    const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out,
+    const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out,
     T *buf, const pocketfft_r<T0> &plan, T0 fct) const
     {
     copy_input(it, in, buf);
@@ -2700,8 +2658,8 @@ struct ExecDcst
   bool cosine;

   template <typename T0, typename T, typename Tplan, size_t vlen>
-  void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in,
-    ndarr<T0> &out, T *buf, const Tplan &plan, T0 fct) const
+  void operator () (const multi_iter<vlen> &it, const cfmav<T0> &in,
+    const fmav<T0> &out, T *buf, const Tplan &plan, T0 fct) const
     {
     copy_input(it, in, buf);
     plan.exec(buf, fct, ortho, type, cosine);
@@ -2710,7 +2668,7 @@ struct ExecDcst
   };

 template<typename T> MRUTIL_NOINLINE void general_r2c(
-  const cndarr<T> &in, ndarr<Cmplx<T>> &out, size_t axis, bool forward, T fct,
+  const cfmav<T> &in, const fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
   size_t nthreads)
   {
   auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
@@ -2719,7 +2677,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
     util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
     [&](Scheduler &sched) {
     constexpr auto vlen = VLEN<T>::val;
-    auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
+    auto storage = alloc_tmp<T,T>(in.shape(), len);
     multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
     if (vlen>1)
@@ -2765,7 +2723,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
});
// end of parallel region
}
template
<
typename
T
>
MRUTIL_NOINLINE
void
general_c2r
(
const
c
ndarr
<
Cmplx
<
T
>>
&
in
,
ndarr
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
const
c
fmav
<
Cmplx
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
)
{
auto
plan
=
get_plan
<
pocketfft_r
<
T
>>
(
out
.
shape
(
axis
));
...
...
@@ -2774,7 +2732,7 @@ template<typename T> MRUTIL_NOINLINE void general_c2r(
     util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
     [&](Scheduler &sched) {
     constexpr auto vlen = VLEN<T>::val;
-    auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
+    auto storage = alloc_tmp<T,T>(out.shape(), len);
     multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
 #ifndef POCKETFFT_NO_VECTORS
     if (vlen>1)
@@ -2841,7 +2799,7 @@ struct ExecR2R
   bool r2c, forward;

   template <typename T0, typename T, size_t vlen> void operator () (
-    const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T *buf,
+    const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out, T *buf,
     const pocketfft_r<T0> &plan, T0 fct) const
     {
     copy_input(it, in, buf);
@@ -2863,8 +2821,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
   {
   if (util::prod(shape)==0) return;
   util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
-  cndarr<Cmplx<T>> ain(data_in, shape, stride_in);
-  ndarr<Cmplx<T>> aout(data_out, shape, stride_out);
+  const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape, stride_in);
+  const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape, stride_out);
   general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
   }
@@ -2875,8 +2833,8 @@ template<typename T> void dct(const shape_t &shape,
   if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
   if (util::prod(shape)==0) return;
   util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
-  cndarr<T> ain(data_in, shape, stride_in);
-  ndarr<T> aout(data_out, shape, stride_out);
+  const cfmav<T> ain(data_in, shape, stride_in);
+  const fmav<T> aout(data_out, shape, stride_out);
   const ExecDcst exec{ortho, type, true};
   if (type==1)
     general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
@@ -2893,8 +2851,8 @@ template<typename T> void dst(const shape_t &shape,
   if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
   if (util::prod(shape)==0) return;
   util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
-  cndarr<T> ain(data_in, shape, stride_in);
-  ndarr<T> aout(data_out, shape, stride_out);
+  const cfmav<T> ain(data_in, shape, stride_in);
+  const fmav<T> aout(data_out, shape, stride_out);
   const ExecDcst exec{ortho, type, false};
   if (type==1)
     general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
@@ -2911,10 +2869,10 @@ template<typename T> void r2c(const shape_t &shape_in,
{
if
(
util
::
prod
(
shape_in
)
==
0
)
return
;
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axis
);
c
ndarr
<
T
>
ain
(
data_in
,
shape_in
,
stride_in
);
c
onst
cfmav
<
T
>
ain
(
data_in
,
shape_in
,
stride_in
);
shape_t
shape_out
(
shape_in
);
shape_out
[
axis
]
=
shape_in
[
axis
]
/
2
+
1
;
ndarr
<
Cmplx
<
T
>>
aout
(
data_out
,
shape_out
,
stride_out
);
const
fmav
<
Cmplx
<
T
>>
aout
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
data_out
)
,
shape_out
,
stride_out
);
general_r2c
(
ain
,
aout
,
axis
,
forward
,
fct
,
nthreads
);
}
...
...
@@ -2945,8 +2903,8 @@ template<typename T> void c2r(const shape_t &shape_out,
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axis
);
shape_t
shape_in
(
shape_out
);
shape_in
[
axis
]
=
shape_out
[
axis
]
/
2
+
1
;
c
ndarr
<
Cmplx
<
T
>>
ain
(
data_in
,
shape_in
,
stride_in
);
ndarr
<
T
>
aout
(
data_out
,
shape_out
,
stride_out
);
c
onst
cfmav
<
Cmplx
<
T
>>
ain
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
data_in
)
,
shape_in
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape_out
,
stride_out
);
general_c2r
(
ain
,
aout
,
axis
,
forward
,
fct
,
nthreads
);
}
...
...
@@ -2964,7 +2922,7 @@ template<typename T> void c2r(const shape_t &shape_out,
   shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
   auto nval = util::prod(shape_in);
   stride_t stride_inter(shape_in.size());
-  stride_inter.back() = sizeof(Cmplx<T>);
+  stride_inter.back() = 1;
   for (int i=int(shape_in.size())-2; i>=0; --i)
     stride_inter[size_t(i)] =
       stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
@@ -2983,8 +2941,8 @@ template<typename T> void r2r_fftpack(const shape_t &shape,
   {
   if (util::prod(shape)==0) return;
   util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
-  cndarr<T> ain(data_in, shape, stride_in);
-  ndarr<T> aout(data_out, shape, stride_out);
+  const cfmav<T> ain(data_in, shape, stride_in);
+  const fmav<T> aout(data_out, shape, stride_out);
   general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
     ExecR2R{real2hermitian, forward});
   }
@@ -2995,8 +2953,8 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
   {
   if (util::prod(shape)==0) return;
   util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
-  cndarr<T> ain(data_in, shape, stride_in);
-  ndarr<T> aout(data_out, shape, stride_out);
+  const cfmav<T> ain(data_in, shape, stride_in);
+  const fmav<T> aout(data_out, shape, stride_out);
   general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{}, false);
   }
@@ -3014,12 +2972,12 @@ template<typename T> void r2r_genuine_hartley(const shape_t &shape,
   tshp[axes.back()] = tshp[axes.back()]/2 + 1;
   aligned_array<std::complex<T>> tdata(util::prod(tshp));
   stride_t tstride(shape.size());
-  tstride.back() = sizeof(std::complex<T>);
+  tstride.back() = 1;
   for (size_t i=tstride.size()-1; i>0; --i)
     tstride[i-1] = tstride[i]*ptrdiff_t(tshp[i]);
   r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
-  cndarr<Cmplx<T>> atmp(tdata.data(), tshp, tstride);
-  ndarr<T> aout(data_out, shape, stride_out);
+  const cfmav<Cmplx<T>> atmp(reinterpret_cast<const Cmplx<T> *>(tdata.data()), tshp, tstride);
+  const fmav<T> aout(data_out, shape, stride_out);
   simple_iter iin(atmp);
   rev_iter iout(aout, axes);
   while(iin.remaining()>0)
@@ -3035,8 +2993,6 @@ template<typename T> void r2r_genuine_hartley(const shape_t &shape,
 using detail_fft::FORWARD;
 using detail_fft::BACKWARD;
-using detail_fft::shape_t;
-using detail_fft::stride_t;
 using detail_fft::c2c;
 using detail_fft::c2r;
 using detail_fft::r2c;

mr_util/mav.h
@@ -24,6 +24,8 @@
 #include <cstdlib>
 #include <array>
 #include <vector>
+#include "mr_util/error_handling.h"

 namespace mr {
@@ -31,7 +33,100 @@ namespace detail_mav {
 using namespace std;

 using shape_t = vector<size_t>;
 using stride_t = vector<ptrdiff_t>;

+class fmav_info
+  {
+  protected:
+    shape_t shp;
+    stride_t str;
+
+    static size_t prod(const shape_t &shape)
+      {
+      size_t res=1;
+      for (auto sz: shape)
+        res*=sz;
+      return res;
+      }
+
+  public:
+    fmav_info(const shape_t &shape_, const stride_t &stride_)
+      : shp(shape_), str(stride_)
+      { MR_assert(shp.size()==str.size(), "dimensions mismatch"); }
+    fmav_info(const shape_t &shape_)
+      : shp(shape_), str(shape_.size())
+      {
+      auto ndim = shp.size();
+      str[ndim-1]=1;
+      for (size_t i=2; i<=ndim; ++i)
+        str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
+      }
+    size_t ndim() const { return shp.size(); }
+    size_t size() const { return prod(shp); }
+    const shape_t &shape() const { return shp; }
+    size_t shape(size_t i) const { return shp[i]; }
+    const stride_t &stride() const { return str; }
+    const ptrdiff_t &stride(size_t i) const { return str[i]; }
+    bool last_contiguous() const
+      { return (str.back()==1); }
+    bool contiguous() const
+      {
+      auto ndim = shp.size();
+      ptrdiff_t stride=1;
+      for (size_t i=0; i<ndim; ++i)
+        {
+        if (str[ndim-1-i]!=stride) return false;
+        stride *= shp[ndim-1-i];
+        }
+      return true;
+      }
+    bool conformable(const fmav_info &other) const
+      { return shp==other.shp; }
+  };
+
+// "mav" stands for "multidimensional array view"
+template<typename T> class cfmav: public fmav_info
+  {
+  protected:
+    T *d;
+
+  public:
+    cfmav(const T *d_, const shape_t &shp_, const stride_t &str_)
+      : fmav_info(shp_, str_), d(const_cast<T *>(d_)) {}
+    cfmav(const T *d_, const shape_t &shp_)
+      : fmav_info(shp_), d(const_cast<T *>(d_)) {}
+    template<typename I> const T &operator[](I i) const
+      { return d[i]; }
+    const T *data() const
+      { return d; }
+  };
+
+template<typename T> class fmav: public cfmav<T>
+  {
+  protected:
+    using parent = cfmav<T>;
+    using parent::d;
+    using parent::shp;
+    using parent::str;
+
+  public:
+    fmav(T *d_, const shape_t &shp_, const stride_t &str_)
+      : parent(d_, shp_, str_) {}
+    fmav(T *d_, const shape_t &shp_)
+      : parent(d_, shp_) {}
+    template<typename I> T &operator[](I i) const
+      { return d[i]; }
+    using parent::shape;
+    using parent::stride;
+    T *data() const
+      { return d; }
+    using parent::last_contiguous;
+    using parent::contiguous;
+    using parent::conformable;
+  };
+
+template<typename T> using const_fmav = fmav<const T>;
+
 template<typename T, size_t ndim> class mav
   {
   static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays");
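A minimal usage sketch for the classes introduced above, assuming mr_util/mav.h is available on the include path. The classes live in mr::detail_mav; only the fmav_info re-export into namespace mr is visible in this diff, so the sketch qualifies names with the detail namespace explicitly:

#include <vector>
#include "mr_util/mav.h"

int main()
  {
  std::vector<double> buf(6, 0.);
  // writable 2x3 view; the shape-only constructor computes C-contiguous
  // element strides {3, 1}
  mr::detail_mav::fmav<double> v(buf.data(), {2, 3});
  v[1*v.stride(0) + 2*v.stride(1)] = 5.;   // flat offset in elements, not bytes
  // read-only view over the same buffer, strides given explicitly
  mr::detail_mav::cfmav<double> cv(buf.data(), {2, 3}, {3, 1});
  return (cv[5] == 5.) ? 0 : 1;
  }

Indices passed to operator[] are flat element offsets derived from the element strides, which is the convention the copy_input/copy_output helpers in mr_util/fft.h rely on after this commit.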
@@ -49,8 +144,8 @@ template<typename T, size_t ndim> class mav
       : d(d_), shp(shp_)
       {
       str[ndim-1]=1;
-      for (size_t d=2; d<=ndim; ++d)
-        str[ndim-d] = str[ndim-d+1]*shp[ndim-d+1];
+      for (size_t i=2; i<=ndim; ++i)
+        str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
       }
     T &operator[](size_t i) const
       { return operator()(i); }
@@ -98,6 +193,8 @@ template<typename T, size_t ndim> class mav
       for (size_t j=0; j<shp[1]; ++j)
         d[str[0]*i + str[1]*j] = val;
       }
+    template<typename T2> bool conformable(const mav<T2,ndim> &other) const
+      { return shp==other.shp; }
   };

 template<typename T, size_t ndim> using const_mav = mav<const T, ndim>;
@@ -112,6 +209,12 @@ template<typename T, size_t ndim> const_mav<T, ndim> nullmav()
 }

 using detail_mav::shape_t;
 using detail_mav::stride_t;
+using detail_mav::fmav_info;