Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Simon Perkins
ducc
Commits
8705186d
Commit
8705186d
authored
Jan 18, 2020
by
Martin Reinecke
Browse files
more mavs
parent
33590924
Changes
5
Hide whitespace changes
Inline
Side-by-side
mr_util/fft.h
View file @
8705186d
...
...
@@ -222,24 +222,10 @@ struct util // hack to avoid duplicate symbols
return
res
;
}
static
MRUTIL_NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
)
static
void
sanity_check_axes
(
size_t
ndim
,
const
shape_t
&
axes
)
{
auto
ndim
=
shape
.
size
();
if
(
ndim
<
1
)
throw
std
::
runtime_error
(
"ndim must be >= 1"
);
if
((
stride_in
.
size
()
!=
ndim
)
||
(
stride_out
.
size
()
!=
ndim
))
throw
std
::
runtime_error
(
"stride dimension mismatch"
);
if
(
inplace
&&
(
stride_in
!=
stride_out
))
throw
std
::
runtime_error
(
"stride mismatch"
);
}
static
MRUTIL_NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
const
shape_t
&
axes
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
auto
ndim
=
shape
.
size
();
shape_t
tmp
(
ndim
,
0
);
if
(
axes
.
empty
())
throw
std
::
invalid_argument
(
"no axes specified"
);
for
(
auto
ax
:
axes
)
{
if
(
ax
>=
ndim
)
throw
std
::
invalid_argument
(
"bad axis number"
);
...
...
@@ -247,12 +233,30 @@ struct util // hack to avoid duplicate symbols
}
}
static
MRUTIL_NOINLINE
void
sanity_check
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
bool
inplace
,
size_t
axis
)
static
MRUTIL_NOINLINE
void
sanity_check_onetype
(
const
fmav_info
&
a1
,
const
fmav_info
&
a2
,
bool
inplace
,
const
shape_t
&
axes
)
{
sanity_check
(
shape
,
stride_in
,
stride_out
,
inplace
);
if
(
axis
>=
shape
.
size
())
throw
std
::
invalid_argument
(
"bad axis number"
);
sanity_check_axes
(
a1
.
ndim
(),
axes
);
MR_assert
(
a1
.
conformable
(
a2
),
"array sizes are not conformable"
);
if
(
inplace
)
MR_assert
(
a1
.
stride
()
==
a2
.
stride
(),
"stride mismatch"
);
}
static
MRUTIL_NOINLINE
void
sanity_check_cr
(
const
fmav_info
&
ac
,
const
fmav_info
&
ar
,
const
shape_t
&
axes
)
{
sanity_check_axes
(
ac
.
ndim
(),
axes
);
MR_assert
(
ac
.
ndim
()
==
ar
.
ndim
(),
"dimension mismatch"
);
for
(
size_t
i
=
0
;
i
<
ac
.
ndim
();
++
i
)
MR_assert
(
ac
.
shape
(
i
)
==
(
i
==
axes
.
back
())
?
(
ar
.
shape
(
i
)
/
2
+
1
)
:
ar
.
shape
(
i
),
"axis length mismatch"
);
}
static
MRUTIL_NOINLINE
void
sanity_check_cr
(
const
fmav_info
&
ac
,
const
fmav_info
&
ar
,
const
size_t
axis
)
{
if
(
axis
>=
ac
.
ndim
())
throw
std
::
invalid_argument
(
"bad axis number"
);
MR_assert
(
ac
.
ndim
()
==
ar
.
ndim
(),
"dimension mismatch"
);
for
(
size_t
i
=
0
;
i
<
ac
.
ndim
();
++
i
)
MR_assert
(
ac
.
shape
(
i
)
==
(
i
==
axis
)
?
(
ar
.
shape
(
i
)
/
2
+
1
)
:
ar
.
shape
(
i
),
"axis length mismatch"
);
}
#ifdef MRUTIL_NO_THREADING
...
...
@@ -2571,7 +2575,7 @@ MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
[
&
](
Scheduler
&
sched
)
{
constexpr
auto
vlen
=
VLEN
<
T0
>::
val
;
auto
storage
=
alloc_tmp
<
T
,
T0
>
(
in
.
shape
(),
len
);
const
auto
&
tin
(
iax
==
0
?
in
:
cfmav
<
T
>
(
out
)
)
;
const
auto
&
tin
(
iax
==
0
?
in
:
out
);
multi_iter
<
vlen
>
it
(
tin
,
out
,
axes
[
iax
],
sched
.
num_threads
(),
sched
.
thread_num
());
#ifndef POCKETFFT_NO_VECTORS
if
(
vlen
>
1
)
...
...
@@ -2814,177 +2818,132 @@ struct ExecR2R
}
};
template
<
typename
T
>
void
c2c
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
std
::
complex
<
T
>
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
c2c
(
const
cfmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
)
;
const
cfmav
<
Cmplx
<
T
>>
a
in
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
data
_in
),
shape
,
stride_
in
);
const
fmav
<
Cmplx
<
T
>>
a
out
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
data_out
),
shape
,
stride_
out
);
general_nd
<
pocketfft_c
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
ExecC2C
{
forward
});
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
)
;
if
(
in
.
size
()
==
0
)
return
;
cfmav
<
Cmplx
<
T
>>
in
2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
fmav
<
Cmplx
<
T
>>
out
2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
data
()),
out
);
general_nd
<
pocketfft_c
<
T
>>
(
in
2
,
out
2
,
axes
,
fct
,
nthreads
,
ExecC2C
{
forward
});
}
template
<
typename
T
>
void
dct
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
int
type
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
dct
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
int
type
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
{
if
((
type
<
1
)
||
(
type
>
4
))
throw
std
::
invalid_argument
(
"invalid DCT type"
);
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
const
cfmav
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape
,
stride_out
);
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
if
(
in
.
size
()
==
0
)
return
;
const
ExecDcst
exec
{
ortho
,
type
,
true
};
if
(
type
==
1
)
general_nd
<
T_dct1
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dct1
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
else
if
(
type
==
4
)
general_nd
<
T_dcst4
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dcst4
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
else
general_nd
<
T_dcst23
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dcst23
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
}
template
<
typename
T
>
void
dst
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
int
type
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
dst
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
int
type
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
{
if
((
type
<
1
)
||
(
type
>
4
))
throw
std
::
invalid_argument
(
"invalid DST type"
);
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
const
cfmav
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape
,
stride_out
);
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
const
ExecDcst
exec
{
ortho
,
type
,
false
};
if
(
type
==
1
)
general_nd
<
T_dst1
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dst1
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
else
if
(
type
==
4
)
general_nd
<
T_dcst4
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dcst4
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
else
general_nd
<
T_dcst23
<
T
>>
(
a
in
,
a
out
,
axes
,
fct
,
nthreads
,
exec
);
general_nd
<
T_dcst23
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
bool
forward
,
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
template
<
typename
T
>
void
r2c
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape_in
)
==
0
)
return
;
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axis
);
const
cfmav
<
T
>
ain
(
data_in
,
shape_in
,
stride_in
);
shape_t
shape_out
(
shape_in
);
shape_out
[
axis
]
=
shape_in
[
axis
]
/
2
+
1
;
const
fmav
<
Cmplx
<
T
>>
aout
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
data_out
),
shape_out
,
stride_out
);
general_r2c
(
ain
,
aout
,
axis
,
forward
,
fct
,
nthreads
);
util
::
sanity_check_cr
(
out
,
in
,
axis
);
if
(
in
.
size
()
==
0
)
return
;
fmav
<
Cmplx
<
T
>>
out2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
data
()),
out
);
general_r2c
(
in
,
out2
,
axis
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2c
(
const
shape_t
&
shape_in
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
T
*
data_in
,
std
::
complex
<
T
>
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2c
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape_in
)
==
0
)
return
;
util
::
sanity_check
(
shape_in
,
stride_in
,
stride_out
,
false
,
axes
);
r2c
(
shape_in
,
stride_in
,
stride_out
,
axes
.
back
(),
forward
,
data_in
,
data_out
,
fct
,
nthreads
);
util
::
sanity_check_cr
(
out
,
in
,
axes
);
if
(
in
.
size
()
==
0
)
return
;
r2c
(
in
,
out
,
axes
.
back
(),
forward
,
fct
,
nthreads
);
if
(
axes
.
size
()
==
1
)
return
;
shape_t
shape_out
(
shape_in
);
shape_out
[
axes
.
back
()]
=
shape_in
[
axes
.
back
()]
/
2
+
1
;
auto
newaxes
=
shape_t
{
axes
.
begin
(),
--
axes
.
end
()};
c2c
(
shape_out
,
stride_out
,
stride_out
,
newaxes
,
forward
,
data_out
,
data_out
,
T
(
1
),
nthreads
);
c2c
(
out
,
out
,
newaxes
,
forward
,
T
(
1
),
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
size_t
axis
,
bool
forward
,
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
c2r
(
const
cfmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape_out
)
==
0
)
return
;
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axis
);
shape_t
shape_in
(
shape_out
);
shape_in
[
axis
]
=
shape_out
[
axis
]
/
2
+
1
;
const
cfmav
<
Cmplx
<
T
>>
ain
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
data_in
),
shape_in
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape_out
,
stride_out
);
general_c2r
(
ain
,
aout
,
axis
,
forward
,
fct
,
nthreads
);
util
::
sanity_check_cr
(
in
,
out
,
axis
);
if
(
in
.
size
()
==
0
)
return
;
cfmav
<
Cmplx
<
T
>>
in2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
general_c2r
(
in2
,
out
,
axis
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
shape_t
&
shape_out
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
forward
,
const
std
::
complex
<
T
>
*
data_in
,
T
*
data_out
,
T
fct
,
template
<
typename
T
>
void
c2r
(
const
cfmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape_out
)
==
0
)
return
;
if
(
axes
.
size
()
==
1
)
return
c2r
(
shape_out
,
stride_in
,
stride_out
,
axes
[
0
],
forward
,
data_in
,
data_out
,
fct
,
nthreads
);
util
::
sanity_check
(
shape_out
,
stride_in
,
stride_out
,
false
,
axes
);
auto
shape_in
=
shape_out
;
shape_in
[
axes
.
back
()]
=
shape_out
[
axes
.
back
()]
/
2
+
1
;
auto
nval
=
util
::
prod
(
shape_in
);
stride_t
stride_inter
(
shape_in
.
size
());
stride_inter
.
back
()
=
1
;
for
(
int
i
=
int
(
shape_in
.
size
())
-
2
;
i
>=
0
;
--
i
)
stride_inter
[
size_t
(
i
)]
=
stride_inter
[
size_t
(
i
+
1
)]
*
ptrdiff_t
(
shape_in
[
size_t
(
i
+
1
)]);
aligned_array
<
std
::
complex
<
T
>>
tmp
(
nval
);
return
c2r
(
in
,
out
,
axes
[
0
],
forward
,
fct
,
nthreads
);
util
::
sanity_check_cr
(
in
,
out
,
axes
);
if
(
in
.
size
()
==
0
)
return
;
aligned_array
<
std
::
complex
<
T
>>
tmp
(
in
.
size
());
fmav
<
std
::
complex
<
T
>>
atmp
(
tmp
.
data
(),
in
);
auto
newaxes
=
shape_t
({
axes
.
begin
(),
--
axes
.
end
()});
c2c
(
shape_in
,
stride_in
,
stride_inter
,
newaxes
,
forward
,
data_in
,
tmp
.
data
(),
T
(
1
),
nthreads
);
c2r
(
shape_out
,
stride_inter
,
stride_out
,
axes
.
back
(),
forward
,
tmp
.
data
(),
data_out
,
fct
,
nthreads
);
c2c
(
in
,
atmp
,
newaxes
,
forward
,
T
(
1
),
nthreads
);
c2r
(
atmp
,
out
,
axes
.
back
(),
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2r_fftpack
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
bool
real2hermitian
,
bool
forward
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2r_fftpack
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
real2hermitian
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
const
cfmav
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape
,
stride_out
);
general_nd
<
pocketfft_r
<
T
>>
(
ain
,
aout
,
axes
,
fct
,
nthreads
,
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
if
(
in
.
size
()
==
0
)
return
;
general_nd
<
pocketfft_r
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
ExecR2R
{
real2hermitian
,
forward
});
}
template
<
typename
T
>
void
r2r_separable_hartley
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2r_separable_hartley
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape
)
==
0
)
return
;
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
);
const
cfmav
<
T
>
ain
(
data_in
,
shape
,
stride_in
);
const
fmav
<
T
>
aout
(
data_out
,
shape
,
stride_out
);
general_nd
<
pocketfft_r
<
T
>>
(
ain
,
aout
,
axes
,
fct
,
nthreads
,
ExecHartley
{},
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
if
(
in
.
size
()
==
0
)
return
;
general_nd
<
pocketfft_r
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
ExecHartley
{},
false
);
}
template
<
typename
T
>
void
r2r_genuine_hartley
(
const
shape_t
&
shape
,
const
stride_t
&
stride_in
,
const
stride_t
&
stride_out
,
const
shape_t
&
axes
,
const
T
*
data_in
,
T
*
data_out
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2r_genuine_hartley
(
const
cfmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
util
::
prod
(
shape
)
==
0
)
return
;
if
(
axes
.
size
()
==
1
)
return
r2r_separable_hartley
(
shape
,
stride_in
,
stride_out
,
axes
,
data_in
,
data_out
,
fct
,
nthread
s
);
util
::
sanity_check
(
shape
,
stride_in
,
stride_out
,
data_in
==
data_out
,
axes
)
;
shape_t
tshp
(
shape
);
return
r2r_separable_hartley
(
in
,
out
,
axes
,
fct
,
nthreads
);
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axe
s
);
if
(
in
.
size
()
==
0
)
return
;
shape_t
tshp
(
in
.
shape
()
);
tshp
[
axes
.
back
()]
=
tshp
[
axes
.
back
()]
/
2
+
1
;
aligned_array
<
std
::
complex
<
T
>>
tdata
(
util
::
prod
(
tshp
));
stride_t
tstride
(
shape
.
size
());
tstride
.
back
()
=
1
;
for
(
size_t
i
=
tstride
.
size
()
-
1
;
i
>
0
;
--
i
)
tstride
[
i
-
1
]
=
tstride
[
i
]
*
ptrdiff_t
(
tshp
[
i
]);
r2c
(
shape
,
stride_in
,
tstride
,
axes
,
true
,
data_in
,
tdata
.
data
(),
fct
,
nthreads
);
const
cfmav
<
Cmplx
<
T
>>
atmp
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
tdata
.
data
()),
tshp
,
tstride
);
const
fmav
<
T
>
aout
(
data_out
,
shape
,
stride_out
);
auto
tinfo
=
fmav_info
(
tshp
);
aligned_array
<
std
::
complex
<
T
>>
tdata
(
tinfo
.
size
());
fmav
<
std
::
complex
<
T
>>
atmp
(
tdata
.
data
(),
tinfo
);
r2c
(
in
,
atmp
,
axes
,
true
,
fct
,
nthreads
);
simple_iter
iin
(
atmp
);
rev_iter
iout
(
a
out
,
axes
);
rev_iter
iout
(
out
,
axes
);
while
(
iin
.
remaining
()
>
0
)
{
auto
v
=
atmp
[
iin
.
ofs
()];
a
out
[
iout
.
ofs
()]
=
v
.
r
+
v
.
i
;
a
out
[
iout
.
rev_ofs
()]
=
v
.
r
-
v
.
i
;
out
[
iout
.
ofs
()]
=
v
.
r
eal
()
+
v
.
imag
()
;
out
[
iout
.
rev_ofs
()]
=
v
.
r
eal
()
-
v
.
imag
()
;
iin
.
advance
();
iout
.
advance
();
}
}
...
...
mr_util/mav.h
View file @
8705186d
...
...
@@ -77,7 +77,7 @@ class fmav_info
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
if
(
str
[
ndim
-
1
-
i
]
!=
stride
)
return
false
;
stride
*=
shp
[
ndim
-
1
-
i
];
stride
*=
ptrdiff_t
(
shp
[
ndim
-
1
-
i
]
)
;
}
return
true
;
}
...
...
@@ -96,6 +96,8 @@ template<typename T> class cfmav: public fmav_info
:
fmav_info
(
shp_
,
str_
),
d
(
const_cast
<
T
*>
(
d_
))
{}
cfmav
(
const
T
*
d_
,
const
shape_t
&
shp_
)
:
fmav_info
(
shp_
),
d
(
const_cast
<
T
*>
(
d_
))
{}
cfmav
(
const
T
*
d_
,
const
fmav_info
&
info
)
:
fmav_info
(
info
),
d
(
const_cast
<
T
*>
(
d_
))
{}
template
<
typename
I
>
const
T
&
operator
[](
I
i
)
const
{
return
d
[
i
];
}
const
T
*
data
()
const
...
...
@@ -114,6 +116,8 @@ template<typename T> class fmav: public cfmav<T>
:
parent
(
d_
,
shp_
,
str_
)
{}
fmav
(
T
*
d_
,
const
shape_t
&
shp_
)
:
parent
(
d_
,
shp_
)
{}
fmav
(
T
*
d_
,
const
fmav_info
&
info
)
:
parent
(
d_
,
info
)
{}
template
<
typename
I
>
T
&
operator
[](
I
i
)
const
{
return
d
[
i
];
}
using
parent
::
shape
;
...
...
@@ -125,36 +129,40 @@ template<typename T> class fmav: public cfmav<T>
using
parent
::
conformable
;
};
template
<
typename
T
>
using
const_fmav
=
fmav
<
const
T
>
;
template
<
typename
T
,
size_t
ndim
>
class
mav
template
<
typename
T
,
size_t
ndim
>
class
cmav
{
static_assert
((
ndim
>
0
)
&&
(
ndim
<
3
),
"only supports 1D and 2D arrays"
);
pr
ivate
:
pr
otected
:
T
*
d
;
array
<
size_t
,
ndim
>
shp
;
array
<
ptrdiff_t
,
ndim
>
str
;
public:
mav
(
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
,
c
mav
(
const
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
,
const
array
<
ptrdiff_t
,
ndim
>
&
str_
)
:
d
(
d_
),
shp
(
shp_
),
str
(
str_
)
{}
mav
(
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
)
:
d
(
d_
),
shp
(
shp_
)
:
d
(
const_cast
<
T
*>
(
d_
)
)
,
shp
(
shp_
),
str
(
str_
)
{}
c
mav
(
const
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
)
:
d
(
const_cast
<
T
*>
(
d_
)
)
,
shp
(
shp_
)
{
str
[
ndim
-
1
]
=
1
;
for
(
size_t
i
=
2
;
i
<=
ndim
;
++
i
)
str
[
ndim
-
i
]
=
str
[
ndim
-
i
+
1
]
*
shp
[
ndim
-
i
+
1
];
}
T
&
operator
[](
size_t
i
)
const
operator
cfmav
<
T
>
()
const
{
return
cfmav
<
T
>
(
d
,
{
shp
.
begin
(),
shp
.
end
()},
{
str
.
begin
(),
str
.
end
()});
}
const
T
&
operator
[](
size_t
i
)
const
{
return
operator
()(
i
);
}
T
&
operator
()(
size_t
i
)
const
const
T
&
operator
()(
size_t
i
)
const
{
static_assert
(
ndim
==
1
,
"ndim must be 1"
);
return
d
[
str
[
0
]
*
i
];
}
T
&
operator
()(
size_t
i
,
size_t
j
)
const
const
T
&
operator
()(
size_t
i
,
size_t
j
)
const
{
static_assert
(
ndim
==
2
,
"ndim must be 2"
);
return
d
[
str
[
0
]
*
i
+
str
[
1
]
*
j
];
...
...
@@ -168,7 +176,7 @@ template<typename T, size_t ndim> class mav
return
res
;
}
ptrdiff_t
stride
(
size_t
i
)
const
{
return
str
[
i
];
}
T
*
data
()
const
const
T
*
data
()
const
{
return
d
;
}
bool
last_contiguous
()
const
{
return
(
str
[
ndim
-
1
]
==
1
)
||
(
str
[
ndim
-
1
]
==
0
);
}
...
...
@@ -182,6 +190,46 @@ template<typename T, size_t ndim> class mav
}
return
true
;
}
template
<
typename
T2
>
bool
conformable
(
const
cmav
<
T2
,
ndim
>
&
other
)
const
{
return
shp
==
other
.
shp
;
}
};
template
<
typename
T
,
size_t
ndim
>
class
mav
:
public
cmav
<
T
,
ndim
>
{
protected:
using
parent
=
cmav
<
T
,
ndim
>
;
using
parent
::
d
;
using
parent
::
shp
;
using
parent
::
str
;
public:
mav
(
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
,
const
array
<
ptrdiff_t
,
ndim
>
&
str_
)
:
parent
(
d_
,
shp_
,
str_
)
{}
mav
(
T
*
d_
,
const
array
<
size_t
,
ndim
>
&
shp_
)
:
parent
(
d_
,
shp_
)
{}
operator
fmav
<
T
>
()
const
{
return
fmav
<
T
>
(
d
,
{
shp
.
begin
(),
shp
.
end
()},
{
str
.
begin
(),
str
.
end
()});
}
T
&
operator
[](
size_t
i
)
const
{
return
operator
()(
i
);
}
T
&
operator
()(
size_t
i
)
const
{
static_assert
(
ndim
==
1
,
"ndim must be 1"
);
return
d
[
str
[
0
]
*
i
];
}
T
&
operator
()(
size_t
i
,
size_t
j
)
const
{
static_assert
(
ndim
==
2
,
"ndim must be 2"
);
return
d
[
str
[
0
]
*
i
+
str
[
1
]
*
j
];
}
using
parent
::
shape
;
using
parent
::
stride
;
using
parent
::
size
;
T
*
data
()
const
{
return
d
;
}
using
parent
::
last_contiguous
;
using
parent
::
contiguous
;
void
fill
(
const
T
&
val
)
const
{
// FIXME: special cases for contiguous arrays and/or zeroing?
...
...
@@ -193,20 +241,9 @@ template<typename T, size_t ndim> class mav
for
(
size_t
j
=
0
;
j
<
shp
[
1
];
++
j
)
d
[
str
[
0
]
*
i
+
str
[
1
]
*
j
]
=
val
;
}
template
<
typename
T2
>
bool
conformable
(
const
mav
<
T2
,
ndim
>
&
other
)
const
{
return
shp
==
other
.
shp
;
}
using
parent
::
conformable
;
};
template
<
typename
T
,
size_t
ndim
>
using
const_mav
=
mav
<
const
T
,
ndim
>
;
template
<
typename
T
,
size_t
ndim
>
const_mav
<
T
,
ndim
>
cmav
(
const
mav
<
T
,
ndim
>
&
mav
)
{
return
const_mav
<
T
,
ndim
>
(
mav
.
data
(),
mav
.
shape
());
}
template
<
typename
T
,
size_t
ndim
>
const_mav
<
T
,
ndim
>
nullmav
()
{
array
<
size_t
,
ndim
>
shp
;
shp
.
fill
(
0
);
return
const_mav
<
T
,
ndim
>
(
nullptr
,
shp
);
}
}
using
detail_mav
::
shape_t
;
...
...
@@ -214,11 +251,8 @@ using detail_mav::stride_t;
using
detail_mav
::
fmav_info
;
using
detail_mav
::
fmav
;
using
detail_mav
::
cfmav
;
using
detail_mav
::
const_fmav
;
using
detail_mav
::
mav
;
using
detail_mav
::
const_mav
;
using
detail_mav
::
cmav
;
using
detail_mav
::
nullmav
;
}
...
...
nifty_gridder/gridder_cxx.h
View file @
8705186d
...
...
@@ -76,7 +76,6 @@ template<typename T, size_t ndim> class tmpStorage
tmpStorage
(
const
array
<
size_t
,
ndim
>
&
shp
)
:
d
(
prod
(
shp
)),
mav_
(
d
.
data
(),
shp
)
{}
mav
<
T
,
ndim
>
&
getMav
()
{
return
mav_
;
}
const_mav
<
T
,
ndim
>
getCmav
()
{
return
cmav
(
mav_
);
}
void
fill
(
const
T
&
val
)
{
std
::
fill
(
d
.
begin
(),
d
.
end
(),
val
);
}
};
...
...
@@ -86,7 +85,7 @@ template<typename T, size_t ndim> class tmpStorage
//
template
<
typename
T
>
void
complex2hartley
(
const
c
onst_
mav
<
complex
<
T
>
,
2
>
&
grid
,
const
mav
<
T
,
2
>
&
grid2
,
size_t
nthreads
)
(
const
cmav
<
complex
<
T
>
,
2
>
&
grid
,
const
mav
<
T
,
2
>
&
grid2
,
size_t
nthreads
)
{
checkShape
(
grid
.
shape
(),
grid2
.
shape
());
size_t
nu
=
grid
.
shape
(
0
),
nv
=
grid
.
shape
(
1
);
...
...
@@ -107,7 +106,7 @@ template<typename T> void complex2hartley
}
template
<
typename
T
>
void
hartley2complex
(
const
c
onst_
mav
<
T
,
2
>
&
grid
,
const
mav
<
complex
<
T
>
,
2
>
&
grid2
,
size_t
nthreads
)
(
const
cmav
<
T
,
2
>
&
grid
,
const
mav
<
complex
<
T
>
,
2
>
&
grid2
,
size_t
nthreads
)
{
checkShape
(
grid
.
shape
(),
grid2
.
shape
());
size_t
nu
=
grid
.
shape
(
0
),
nv
=
grid
.
shape
(
1
);
...
...
@@ -128,17 +127,13 @@ template<typename T> void hartley2complex
});
}
template
<
typename
T
>
void
hartley2_2D
(
const
c
onst_
mav
<
T
,
2
>
&
in
,
template
<
typename
T
>
void
hartley2_2D
(
const
cmav
<
T
,
2
>
&
in
,
const
mav
<
T
,
2
>
&
out
,
size_t
nthreads
)
{
checkShape
(
in
.
shape
(),
out
.
shape
());
size_t
nu
=
in
.
shape
(
0
),
nv
=
in
.
shape
(
1
);
stride_t
stri
{
in
.
stride
(
0
),
in
.
stride
(
1
)};
stride_t
stro
{
out
.
stride
(
0
),
out
.
stride
(
1
)};
auto
d_i
=
in
.
data
();
auto
ptmp
=
out
.
data
();
r2r_separable_hartley
({
nu
,
nv
},
stri
,
stro
,
{
0
,
1
},
d_i
,
ptmp
,
T
(
1
),
nthreads
);
r2r_separable_hartley
(
cfmav
<
T
>
(
in
),
fmav
<
T
>
(
out
),
{
0
,
1
},
T
(
1
),
nthreads
);
execStatic
((
nu
+
1
)
/
2
-
1
,
nthreads
,
0
,
[
&
](
Scheduler
&
sched
)
{
while
(
auto
rng
=
sched
.
getNext
())
for
(
auto
i
=
rng
.
lo
+
1
;
i
<
rng
.
hi
+
1
;
++
i
)
...
...
@@ -285,8 +280,8 @@ class Baselines
idx_t
shift
,
mask
;
public:
template
<
typename
T
>
Baselines
(
const
c
onst_
mav
<
T
,
2
>
&
coord_
,
const
c
onst_
mav
<
T
,
1
>
&
freq
,
bool
negate_v
=
false
)
template
<
typename
T
>
Baselines
(
const
cmav
<
T
,
2
>
&
coord_
,
const
cmav
<
T
,
1
>
&
freq
,
bool
negate_v
=
false
)
{
constexpr
double
speedOfLight
=
299792458.
;
MR_assert
(
coord_
.
shape
(
1
)
==
3
,
"dimension mismatch"
);
...
...
@@ -457,7 +452,7 @@ class GridderConfig
});
}
template
<
typename
T
>
void
grid2dirty
(
const
c
onst_
mav
<
T
,
2
>
&
grid
,
template
<
typename
T
>
void
grid2dirty
(
const
cmav
<
T
,
2
>
&
grid
,
const
mav
<
T
,
2
>
&
dirty
)
const
{
checkShape
(
grid
.
shape
(),
{
nu
,
nv
});
...
...
@@ -471,13 +466,12 @@ class GridderConfig
(
const
mav
<
complex
<
T
>
,
2
>
&
grid
,
const
mav
<
T
,
2
>
&
dirty
,
T
w
)
const
{
checkShape
(
grid
.
shape
(),
{
nu
,
nv
});
c2c
({
nu
,
nv
},{
grid
.
stride
(
0
),
grid
.
stride
(
1
)},
{
grid
.
stride
(
0
),
grid
.
stride
(
1
)},
{
0
,
1
},
BACKWARD
,
grid
.
data
(),
grid
.
data
(),
T
(
1
),
nthreads
);
fmav
<
complex
<
T
>>
inout
(
grid
);
c2c
(
inout
,
inout
,
{
0
,
1
},
BACKWARD
,
T
(
1
),
nthreads
);
grid2dirty_post2
(
grid
,
dirty
,
w
);
}
template
<
typename
T
>
void
dirty2grid_pre
(
const
c
onst_
mav
<
T
,
2
>
&
dirty
,
template
<
typename
T
>
void
dirty2grid_pre
(
const
cmav
<
T
,
2
>
&
dirty
,
const
mav
<
T
,
2
>
&
grid
)
const
{
checkShape
(
dirty
.
shape
(),
{
nx_dirty
,
ny_dirty
});
...
...
@@ -502,7 +496,7 @@ class GridderConfig
}