Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
ducc
Commits
f1c7d8d8
Commit
f1c7d8d8
authored
Feb 23, 2020
by
Martin Reinecke
Browse files
Merge branch 'mav_rewrite' into 'master'
Mav rewrite See merge request mtr/cxxbase!1
parents
0d8a2890
43f38678
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Healpix_cxx/healpix_base.cc
View file @
f1c7d8d8
...
...
@@ -498,9 +498,9 @@ template<typename I> template<typename I2>
double
dr
=
base
[
o
].
max_pixrad
();
// safety distance
for
(
size_t
i
=
0
;
i
<
nv
;
++
i
)
{
crlimit
(
o
,
i
,
0
)
=
(
rad
[
i
]
+
dr
>
pi
)
?
-
1.
:
cos
(
rad
[
i
]
+
dr
);
crlimit
(
o
,
i
,
1
)
=
(
o
==
0
)
?
cos
(
rad
[
i
])
:
crlimit
(
0
,
i
,
1
);
crlimit
(
o
,
i
,
2
)
=
(
rad
[
i
]
-
dr
<
0.
)
?
1.
:
cos
(
rad
[
i
]
-
dr
);
crlimit
.
v
(
o
,
i
,
0
)
=
(
rad
[
i
]
+
dr
>
pi
)
?
-
1.
:
cos
(
rad
[
i
]
+
dr
);
crlimit
.
v
(
o
,
i
,
1
)
=
(
o
==
0
)
?
cos
(
rad
[
i
])
:
crlimit
(
0
,
i
,
1
);
crlimit
.
v
(
o
,
i
,
2
)
=
(
rad
[
i
]
-
dr
<
0.
)
?
1.
:
cos
(
rad
[
i
]
-
dr
);
}
}
...
...
@@ -563,9 +563,9 @@ template<typename I> void T_Healpix_Base<I>::query_multidisc_general
double
dr
=
base
[
o
].
max_pixrad
();
// safety distance
for
(
size_t
i
=
0
;
i
<
nv
;
++
i
)
{
crlimit
(
o
,
i
,
0
)
=
(
rad
[
i
]
+
dr
>
pi
)
?
-
1.
:
cos
(
rad
[
i
]
+
dr
);
crlimit
(
o
,
i
,
1
)
=
(
o
==
0
)
?
cos
(
rad
[
i
])
:
crlimit
(
0
,
i
,
1
);
crlimit
(
o
,
i
,
2
)
=
(
rad
[
i
]
-
dr
<
0.
)
?
1.
:
cos
(
rad
[
i
]
-
dr
);
crlimit
.
v
(
o
,
i
,
0
)
=
(
rad
[
i
]
+
dr
>
pi
)
?
-
1.
:
cos
(
rad
[
i
]
+
dr
);
crlimit
.
v
(
o
,
i
,
1
)
=
(
o
==
0
)
?
cos
(
rad
[
i
])
:
crlimit
(
0
,
i
,
1
);
crlimit
.
v
(
o
,
i
,
2
)
=
(
rad
[
i
]
-
dr
<
0.
)
?
1.
:
cos
(
rad
[
i
]
-
dr
);
}
}
...
...
mr_util/error_handling.h
View file @
f1c7d8d8
...
...
@@ -25,6 +25,8 @@
#include <sstream>
#include <exception>
#include "mr_util/useful_macros.h"
namespace
mr
{
namespace
detail_error_handling
{
...
...
@@ -35,19 +37,6 @@ namespace detail_error_handling {
#define MRUTIL_ERROR_HANDLING_LOC_ ::mr::detail_error_handling::CodeLocation(__FILE__, __LINE__)
#endif
#define MR_fail(...) \
do { \
::std::ostringstream msg; \
::mr::detail_error_handling::streamDump__(msg, MRUTIL_ERROR_HANDLING_LOC_, "\n", ##__VA_ARGS__, "\n"); \
throw ::std::runtime_error(msg.str()); \
} while(0)
#define MR_assert(cond,...) \
do { \
if (cond); \
else { MR_fail("Assertion failure\n", ##__VA_ARGS__); } \
} while(0)
// to be replaced with std::source_location once generally available
class
CodeLocation
{
...
...
@@ -73,21 +62,39 @@ inline ::std::ostream &operator<<(::std::ostream &os, const CodeLocation &loc)
#if (__cplusplus>=201703L) // hyper-elegant C++2017 version
template
<
typename
...
Args
>
inline
void
streamDump__
(
::
std
::
ostream
&
os
,
Args
&&
...
args
)
void
streamDump__
(
::
std
::
ostream
&
os
,
Args
&&
...
args
)
{
(
os
<<
...
<<
args
);
}
#else
template
<
typename
T
>
inline
void
streamDump__
(
::
std
::
ostream
&
os
,
const
T
&
value
)
void
streamDump__
(
::
std
::
ostream
&
os
,
const
T
&
value
)
{
os
<<
value
;
}
template
<
typename
T
,
typename
...
Args
>
inline
void
streamDump__
(
::
std
::
ostream
&
os
,
const
T
&
value
,
void
streamDump__
(
::
std
::
ostream
&
os
,
const
T
&
value
,
const
Args
&
...
args
)
{
os
<<
value
;
streamDump__
(
os
,
args
...);
}
#endif
template
<
typename
...
Args
>
[[
noreturn
]]
void
MRUTIL_NOINLINE
fail__
(
Args
&&
...
args
)
{
::
std
::
ostringstream
msg
;
\
::
mr
::
detail_error_handling
::
streamDump__
(
msg
,
args
...);
\
throw
::
std
::
runtime_error
(
msg
.
str
());
\
}
#define MR_fail(...) \
do { \
::mr::detail_error_handling::fail__(MRUTIL_ERROR_HANDLING_LOC_, "\n", ##__VA_ARGS__, "\n"); \
} while(0)
#define MR_assert(cond,...) \
do { \
if (cond); \
else { MR_fail("Assertion failure\n", ##__VA_ARGS__); } \
} while(0)
}}
...
...
mr_util/fft.h
View file @
f1c7d8d8
...
...
@@ -621,7 +621,7 @@ template<typename T, typename T0> aligned_array<T> alloc_tmp
}
template
<
typename
T
,
size_t
vlen
>
void
copy_input
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
Cmplx
<
T
>>
&
src
,
Cmplx
<
native_simd
<
T
>>
*
MRUTIL_RESTRICT
dst
)
const
fmav
<
Cmplx
<
T
>>
&
src
,
Cmplx
<
native_simd
<
T
>>
*
MRUTIL_RESTRICT
dst
)
{
for
(
size_t
i
=
0
;
i
<
it
.
length_in
();
++
i
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
...
...
@@ -632,7 +632,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template
<
typename
T
,
size_t
vlen
>
void
copy_input
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
T
>
&
src
,
native_simd
<
T
>
*
MRUTIL_RESTRICT
dst
)
const
fmav
<
T
>
&
src
,
native_simd
<
T
>
*
MRUTIL_RESTRICT
dst
)
{
for
(
size_t
i
=
0
;
i
<
it
.
length_in
();
++
i
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
...
...
@@ -640,7 +640,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template
<
typename
T
,
size_t
vlen
>
void
copy_input
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
T
>
&
src
,
T
*
MRUTIL_RESTRICT
dst
)
const
fmav
<
T
>
&
src
,
T
*
MRUTIL_RESTRICT
dst
)
{
if
(
dst
==
&
src
[
it
.
iofs
(
0
)])
return
;
// in-place
for
(
size_t
i
=
0
;
i
<
it
.
length_in
();
++
i
)
...
...
@@ -648,27 +648,30 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template
<
typename
T
,
size_t
vlen
>
void
copy_output
(
const
multi_iter
<
vlen
>
&
it
,
const
Cmplx
<
native_simd
<
T
>>
*
MRUTIL_RESTRICT
src
,
const
fmav
<
Cmplx
<
T
>>
&
dst
)
const
Cmplx
<
native_simd
<
T
>>
*
MRUTIL_RESTRICT
src
,
fmav
<
Cmplx
<
T
>>
&
dst
)
{
auto
ptr
=
dst
.
vdata
();
for
(
size_t
i
=
0
;
i
<
it
.
length_out
();
++
i
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
dst
[
it
.
oofs
(
j
,
i
)].
Set
(
src
[
i
].
r
[
j
],
src
[
i
].
i
[
j
]);
ptr
[
it
.
oofs
(
j
,
i
)].
Set
(
src
[
i
].
r
[
j
],
src
[
i
].
i
[
j
]);
}
template
<
typename
T
,
size_t
vlen
>
void
copy_output
(
const
multi_iter
<
vlen
>
&
it
,
const
native_simd
<
T
>
*
MRUTIL_RESTRICT
src
,
const
fmav
<
T
>
&
dst
)
const
native_simd
<
T
>
*
MRUTIL_RESTRICT
src
,
fmav
<
T
>
&
dst
)
{
auto
ptr
=
dst
.
vdata
();
for
(
size_t
i
=
0
;
i
<
it
.
length_out
();
++
i
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
dst
[
it
.
oofs
(
j
,
i
)]
=
src
[
i
][
j
];
ptr
[
it
.
oofs
(
j
,
i
)]
=
src
[
i
][
j
];
}
template
<
typename
T
,
size_t
vlen
>
void
copy_output
(
const
multi_iter
<
vlen
>
&
it
,
const
T
*
MRUTIL_RESTRICT
src
,
const
fmav
<
T
>
&
dst
)
const
T
*
MRUTIL_RESTRICT
src
,
fmav
<
T
>
&
dst
)
{
auto
ptr
=
dst
.
vdata
();
if
(
src
==
&
dst
[
it
.
oofs
(
0
)])
return
;
// in-place
for
(
size_t
i
=
0
;
i
<
it
.
length_out
();
++
i
)
dst
[
it
.
oofs
(
i
)]
=
src
[
i
];
ptr
[
it
.
oofs
(
i
)]
=
src
[
i
];
}
template
<
typename
T
>
struct
add_vec
{
using
type
=
native_simd
<
T
>
;
};
...
...
@@ -677,7 +680,7 @@ template <typename T> struct add_vec<Cmplx<T>>
template
<
typename
T
>
using
add_vec_t
=
typename
add_vec
<
T
>::
type
;
template
<
typename
Tplan
,
typename
T
,
typename
T0
,
typename
Exec
>
MRUTIL_NOINLINE
void
general_nd
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
MRUTIL_NOINLINE
void
general_nd
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T0
fct
,
size_t
nthreads
,
const
Exec
&
exec
,
const
bool
allow_inplace
=
true
)
{
...
...
@@ -709,7 +712,7 @@ MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
{
it
.
advance
(
1
);
auto
buf
=
allow_inplace
&&
it
.
stride_out
()
==
1
?
&
out
[
it
.
oofs
(
0
)
]
:
reinterpret_cast
<
T
*>
(
storage
.
data
());
&
out
.
vraw
(
it
.
oofs
(
0
)
)
:
reinterpret_cast
<
T
*>
(
storage
.
data
());
exec
(
it
,
tin
,
out
,
buf
,
*
plan
,
fct
);
}
});
// end of parallel region
...
...
@@ -721,9 +724,9 @@ struct ExecC2C
{
bool
forward
;
template
<
typename
T0
,
typename
T
,
size_t
vlen
>
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
Cmplx
<
T0
>>
&
in
,
const
fmav
<
Cmplx
<
T0
>>
&
out
,
T
*
buf
,
const
pocketfft_c
<
T0
>
&
plan
,
T0
fct
)
const
template
<
typename
T0
,
typename
T
,
size_t
vlen
>
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
fmav
<
Cmplx
<
T0
>>
&
in
,
fmav
<
Cmplx
<
T0
>>
&
out
,
T
*
buf
,
const
pocketfft_c
<
T0
>
&
plan
,
T0
fct
)
const
{
copy_input
(
it
,
in
,
buf
);
plan
.
exec
(
buf
,
fct
,
forward
);
...
...
@@ -732,40 +735,42 @@ struct ExecC2C
};
template
<
typename
T
,
size_t
vlen
>
void
copy_hartley
(
const
multi_iter
<
vlen
>
&
it
,
const
native_simd
<
T
>
*
MRUTIL_RESTRICT
src
,
const
fmav
<
T
>
&
dst
)
const
native_simd
<
T
>
*
MRUTIL_RESTRICT
src
,
fmav
<
T
>
&
dst
)
{
auto
ptr
=
dst
.
vdata
();
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
dst
[
it
.
oofs
(
j
,
0
)]
=
src
[
0
][
j
];
ptr
[
it
.
oofs
(
j
,
0
)]
=
src
[
0
][
j
];
size_t
i
=
1
,
i1
=
1
,
i2
=
it
.
length_out
()
-
1
;
for
(
i
=
1
;
i
<
it
.
length_out
()
-
1
;
i
+=
2
,
++
i1
,
--
i2
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
{
dst
[
it
.
oofs
(
j
,
i1
)]
=
src
[
i
][
j
]
+
src
[
i
+
1
][
j
];
dst
[
it
.
oofs
(
j
,
i2
)]
=
src
[
i
][
j
]
-
src
[
i
+
1
][
j
];
ptr
[
it
.
oofs
(
j
,
i1
)]
=
src
[
i
][
j
]
+
src
[
i
+
1
][
j
];
ptr
[
it
.
oofs
(
j
,
i2
)]
=
src
[
i
][
j
]
-
src
[
i
+
1
][
j
];
}
if
(
i
<
it
.
length_out
())
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
dst
[
it
.
oofs
(
j
,
i1
)]
=
src
[
i
][
j
];
ptr
[
it
.
oofs
(
j
,
i1
)]
=
src
[
i
][
j
];
}
template
<
typename
T
,
size_t
vlen
>
void
copy_hartley
(
const
multi_iter
<
vlen
>
&
it
,
const
T
*
MRUTIL_RESTRICT
src
,
const
fmav
<
T
>
&
dst
)
const
T
*
MRUTIL_RESTRICT
src
,
fmav
<
T
>
&
dst
)
{
dst
[
it
.
oofs
(
0
)]
=
src
[
0
];
auto
ptr
=
dst
.
vdata
();
ptr
[
it
.
oofs
(
0
)]
=
src
[
0
];
size_t
i
=
1
,
i1
=
1
,
i2
=
it
.
length_out
()
-
1
;
for
(
i
=
1
;
i
<
it
.
length_out
()
-
1
;
i
+=
2
,
++
i1
,
--
i2
)
{
dst
[
it
.
oofs
(
i1
)]
=
src
[
i
]
+
src
[
i
+
1
];
dst
[
it
.
oofs
(
i2
)]
=
src
[
i
]
-
src
[
i
+
1
];
ptr
[
it
.
oofs
(
i1
)]
=
src
[
i
]
+
src
[
i
+
1
];
ptr
[
it
.
oofs
(
i2
)]
=
src
[
i
]
-
src
[
i
+
1
];
}
if
(
i
<
it
.
length_out
())
dst
[
it
.
oofs
(
i1
)]
=
src
[
i
];
ptr
[
it
.
oofs
(
i1
)]
=
src
[
i
];
}
struct
ExecHartley
{
template
<
typename
T0
,
typename
T
,
size_t
vlen
>
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
T0
>
&
in
,
const
fmav
<
T0
>
&
out
,
const
multi_iter
<
vlen
>
&
it
,
const
fmav
<
T0
>
&
in
,
fmav
<
T0
>
&
out
,
T
*
buf
,
const
pocketfft_r
<
T0
>
&
plan
,
T0
fct
)
const
{
copy_input
(
it
,
in
,
buf
);
...
...
@@ -781,8 +786,8 @@ struct ExecDcst
bool
cosine
;
template
<
typename
T0
,
typename
T
,
typename
Tplan
,
size_t
vlen
>
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
T0
>
&
in
,
const
fmav
<
T0
>
&
out
,
T
*
buf
,
const
Tplan
&
plan
,
T0
fct
)
const
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
fmav
<
T0
>
&
in
,
fmav
<
T0
>
&
out
,
T
*
buf
,
const
Tplan
&
plan
,
T0
fct
)
const
{
copy_input
(
it
,
in
,
buf
);
plan
.
exec
(
buf
,
fct
,
ortho
,
type
,
cosine
);
...
...
@@ -791,7 +796,7 @@ struct ExecDcst
};
template
<
typename
T
>
MRUTIL_NOINLINE
void
general_r2c
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
Cmplx
<
T
>>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
const
fmav
<
T
>
&
in
,
fmav
<
Cmplx
<
T
>>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
)
{
auto
plan
=
get_plan
<
pocketfft_r
<
T
>>
(
in
.
shape
(
axis
));
...
...
@@ -810,20 +815,21 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto
tdatav
=
reinterpret_cast
<
native_simd
<
T
>
*>
(
storage
.
data
());
copy_input
(
it
,
in
,
tdatav
);
plan
->
exec
(
tdatav
,
fct
,
true
);
auto
vout
=
out
.
vdata
();
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
out
[
it
.
oofs
(
j
,
0
)].
Set
(
tdatav
[
0
][
j
]);
v
out
[
it
.
oofs
(
j
,
0
)].
Set
(
tdatav
[
0
][
j
]);
size_t
i
=
1
,
ii
=
1
;
if
(
forward
)
for
(;
i
<
len
-
1
;
i
+=
2
,
++
ii
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
],
tdatav
[
i
+
1
][
j
]);
v
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
],
tdatav
[
i
+
1
][
j
]);
else
for
(;
i
<
len
-
1
;
i
+=
2
,
++
ii
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
],
-
tdatav
[
i
+
1
][
j
]);
v
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
],
-
tdatav
[
i
+
1
][
j
]);
if
(
i
<
len
)
for
(
size_t
j
=
0
;
j
<
vlen
;
++
j
)
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
]);
v
out
[
it
.
oofs
(
j
,
ii
)].
Set
(
tdatav
[
i
][
j
]);
}
#endif
while
(
it
.
remaining
()
>
0
)
...
...
@@ -832,21 +838,22 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto
tdata
=
reinterpret_cast
<
T
*>
(
storage
.
data
());
copy_input
(
it
,
in
,
tdata
);
plan
->
exec
(
tdata
,
fct
,
true
);
out
[
it
.
oofs
(
0
)].
Set
(
tdata
[
0
]);
auto
vout
=
out
.
vdata
();
vout
[
it
.
oofs
(
0
)].
Set
(
tdata
[
0
]);
size_t
i
=
1
,
ii
=
1
;
if
(
forward
)
for
(;
i
<
len
-
1
;
i
+=
2
,
++
ii
)
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
],
tdata
[
i
+
1
]);
v
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
],
tdata
[
i
+
1
]);
else
for
(;
i
<
len
-
1
;
i
+=
2
,
++
ii
)
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
],
-
tdata
[
i
+
1
]);
v
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
],
-
tdata
[
i
+
1
]);
if
(
i
<
len
)
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
]);
v
out
[
it
.
oofs
(
ii
)].
Set
(
tdata
[
i
]);
}
});
// end of parallel region
}
template
<
typename
T
>
MRUTIL_NOINLINE
void
general_c2r
(
const
c
fmav
<
Cmplx
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
const
fmav
<
Cmplx
<
T
>>
&
in
,
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
)
{
auto
plan
=
get_plan
<
pocketfft_r
<
T
>>
(
out
.
shape
(
axis
));
...
...
@@ -922,7 +929,7 @@ struct ExecR2R
bool
r2c
,
forward
;
template
<
typename
T0
,
typename
T
,
size_t
vlen
>
void
operator
()
(
const
multi_iter
<
vlen
>
&
it
,
const
c
fmav
<
T0
>
&
in
,
const
fmav
<
T0
>
&
out
,
T
*
buf
,
const
multi_iter
<
vlen
>
&
it
,
const
fmav
<
T0
>
&
in
,
fmav
<
T0
>
&
out
,
T
*
buf
,
const
pocketfft_r
<
T0
>
&
plan
,
T0
fct
)
const
{
copy_input
(
it
,
in
,
buf
);
...
...
@@ -937,18 +944,18 @@ struct ExecR2R
}
};
template
<
typename
T
>
void
c2c
(
const
c
fmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
template
<
typename
T
>
void
c2c
(
const
fmav
<
std
::
complex
<
T
>>
&
in
,
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
if
(
in
.
size
()
==
0
)
return
;
c
fmav
<
Cmplx
<
T
>>
in2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
fmav
<
Cmplx
<
T
>>
out2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
data
()),
out
);
fmav
<
Cmplx
<
T
>>
in2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
fmav
<
Cmplx
<
T
>>
out2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
v
data
()),
out
,
out
.
writable
()
);
general_nd
<
pocketfft_c
<
T
>>
(
in2
,
out2
,
axes
,
fct
,
nthreads
,
ExecC2C
{
forward
});
}
template
<
typename
T
>
void
dct
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
template
<
typename
T
>
void
dct
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
int
type
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
{
if
((
type
<
1
)
||
(
type
>
4
))
throw
std
::
invalid_argument
(
"invalid DCT type"
);
...
...
@@ -963,7 +970,7 @@ template<typename T> void dct(const cfmav<T> &in, const fmav<T> &out,
general_nd
<
T_dcst23
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
}
template
<
typename
T
>
void
dst
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
template
<
typename
T
>
void
dst
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
int
type
,
T
fct
,
bool
ortho
,
size_t
nthreads
=
1
)
{
if
((
type
<
1
)
||
(
type
>
4
))
throw
std
::
invalid_argument
(
"invalid DST type"
);
...
...
@@ -977,18 +984,18 @@ template<typename T> void dst(const cfmav<T> &in, const fmav<T> &out,
general_nd
<
T_dcst23
<
T
>>
(
in
,
out
,
axes
,
fct
,
nthreads
,
exec
);
}
template
<
typename
T
>
void
r2c
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
template
<
typename
T
>
void
r2c
(
const
fmav
<
T
>
&
in
,
fmav
<
std
::
complex
<
T
>>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_cr
(
out
,
in
,
axis
);
if
(
in
.
size
()
==
0
)
return
;
fmav
<
Cmplx
<
T
>>
out2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
data
()),
out
);
fmav
<
Cmplx
<
T
>>
out2
(
reinterpret_cast
<
Cmplx
<
T
>
*>
(
out
.
v
data
()),
out
,
out
.
writable
()
);
general_r2c
(
in
,
out2
,
axis
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2c
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
template
<
typename
T
>
void
r2c
(
const
fmav
<
T
>
&
in
,
fmav
<
std
::
complex
<
T
>>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_cr
(
out
,
in
,
axes
);
...
...
@@ -1000,17 +1007,17 @@ template<typename T> void r2c(const cfmav<T> &in,
c2c
(
out
,
out
,
newaxes
,
forward
,
T
(
1
),
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
c
fmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
c2r
(
const
fmav
<
std
::
complex
<
T
>>
&
in
,
fmav
<
T
>
&
out
,
size_t
axis
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_cr
(
in
,
out
,
axis
);
if
(
in
.
size
()
==
0
)
return
;
c
fmav
<
Cmplx
<
T
>>
in2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
fmav
<
Cmplx
<
T
>>
in2
(
reinterpret_cast
<
const
Cmplx
<
T
>
*>
(
in
.
data
()),
in
);
general_c2r
(
in2
,
out
,
axis
,
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
c2r
(
const
c
fmav
<
std
::
complex
<
T
>>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
template
<
typename
T
>
void
c2r
(
const
fmav
<
std
::
complex
<
T
>>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
axes
.
size
()
==
1
)
...
...
@@ -1023,8 +1030,8 @@ template<typename T> void c2r(const cfmav<std::complex<T>> &in,
c2r
(
atmp
,
out
,
axes
.
back
(),
forward
,
fct
,
nthreads
);
}
template
<
typename
T
>
void
r2r_fftpack
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
real2hermitian
,
bool
forward
,
template
<
typename
T
>
void
r2r_fftpack
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
bool
real2hermitian
,
bool
forward
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
...
...
@@ -1033,8 +1040,8 @@ template<typename T> void r2r_fftpack(const cfmav<T> &in,
ExecR2R
{
real2hermitian
,
forward
});
}
template
<
typename
T
>
void
r2r_separable_hartley
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2r_separable_hartley
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
{
util
::
sanity_check_onetype
(
in
,
out
,
in
.
data
()
==
out
.
data
(),
axes
);
if
(
in
.
size
()
==
0
)
return
;
...
...
@@ -1042,8 +1049,8 @@ template<typename T> void r2r_separable_hartley(const cfmav<T> &in,
false
);
}
template
<
typename
T
>
void
r2r_genuine_hartley
(
const
c
fmav
<
T
>
&
in
,
const
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
template
<
typename
T
>
void
r2r_genuine_hartley
(
const
fmav
<
T
>
&
in
,
fmav
<
T
>
&
out
,
const
shape_t
&
axes
,
T
fct
,
size_t
nthreads
=
1
)
{
if
(
axes
.
size
()
==
1
)
return
r2r_separable_hartley
(
in
,
out
,
axes
,
fct
,
nthreads
);
...
...
@@ -1055,11 +1062,12 @@ template<typename T> void r2r_genuine_hartley(const cfmav<T> &in,
r2c
(
in
,
atmp
,
axes
,
true
,
fct
,
nthreads
);
simple_iter
iin
(
atmp
);
rev_iter
iout
(
out
,
axes
);
auto
vout
=
out
.
vdata
();
while
(
iin
.
remaining
()
>
0
)
{
auto
v
=
atmp
[
iin
.
ofs
()];
out
[
iout
.
ofs
()]
=
v
.
real
()
+
v
.
imag
();
out
[
iout
.
rev_ofs
()]
=
v
.
real
()
-
v
.
imag
();
v
out
[
iout
.
ofs
()]
=
v
.
real
()
+
v
.
imag
();
v
out
[
iout
.
rev_ofs
()]
=
v
.
real
()
-
v
.
imag
();
iin
.
advance
();
iout
.
advance
();
}
}
...
...
mr_util/mav.h
View file @
f1c7d8d8
...
...
@@ -16,7 +16,7 @@
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
/* Copyright (C) 2019 Max-Planck-Society
/* Copyright (C) 2019
-2020
Max-Planck-Society
Author: Martin Reinecke */
#ifndef MRUTIL_MAV_H
...
...
@@ -42,6 +42,7 @@ class fmav_info
protected:
shape_t
shp
;
stride_t
str
;
size_t
sz
;
static
size_t
prod
(
const
shape_t
&
shape
)
{
...
...
@@ -53,18 +54,22 @@ class fmav_info
public:
fmav_info
(
const
shape_t
&
shape_
,
const
stride_t
&
stride_
)
:
shp
(
shape_
),
str
(
stride_
)
{
MR_assert
(
shp
.
size
()
==
str
.
size
(),
"dimensions mismatch"
);
}
:
shp
(
shape_
),
str
(
stride_
),
sz
(
prod
(
shp
))
{
MR_assert
(
shp
.
size
()
>
0
,
"at least 1D required"
);
MR_assert
(
shp
.
size
()
==
str
.
size
(),
"dimensions mismatch"
);
}
fmav_info
(
const
shape_t
&
shape_
)
:
shp
(
shape_
),
str
(
shape_
.
size
())
:
shp
(
shape_
),
str
(
shape_
.
size
())
,
sz
(
prod
(
shp
))
{
auto
ndim
=
shp
.
size
();
MR_assert
(
ndim
>
0
,
"at least 1D required"
);
str
[
ndim
-
1
]
=
1
;
for
(
size_t
i
=
2
;
i
<=
ndim
;
++
i
)
str
[
ndim
-
i
]
=
str
[
ndim
-
i
+
1
]
*
ptrdiff_t
(
shp
[
ndim
-
i
+
1
]);
}
size_t
ndim
()
const
{
return
shp
.
size
();
}
size_t
size
()
const
{
return
prod
(
shp
)
;
}
size_t
size
()
const
{
return
sz
;
}
const
shape_t
&
shape
()
const
{
return
shp
;
}
size_t
shape
(
size_t
i
)
const
{
return
shp
[
i
];
}
const
stride_t
&
stride
()
const
{
return
str
;
}
...
...
@@ -86,215 +91,284 @@ class fmav_info
{
return
shp
==
other
.
shp
;
}
};
template
<
typename
T
,
size_t
ndim
>
class
cmav
;
// "mav" stands for "multidimensional array view"
template
<
typename
T
>
class
cfmav
:
public
fmav_info
template
<
typename
T
>
class
membuf
{
protected:
using
Tsp
=
shared_ptr
<
vector
<
T
>>
;
Tsp
ptr
;
T
*
d
;
const
T
*
d
;
bool
rw
;
public:
cfmav
(
const
T
*
d_
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
)
:
fmav_info
(
shp_
,
str_
),
d
(
const_cast
<
T
*>
(
d_
))
{}
cfmav
(
const
T
*
d_
,
const
shape_t
&
shp_
)
:
fmav_info
(
shp_
),
d
(
const_cast
<
T
*>
(
d_
))
{}
cfmav
(
const
shape_t
&
shp_
)
:
fmav_info
(
shp_
),
ptr
(
make_unique
<
vector
<
T
>>
(
size
())),
d
(
const_cast
<
T
*>
(
ptr
->
data
()))
{}
cfmav
(
const
T
*
d_
,
const
fmav_info
&
info
)
:
fmav_info
(
info
),
d
(
const_cast
<
T
*>
(
d_
))
{}
cfmav
(
const
cfmav
&
other
)
=
default
;
cfmav
(
cfmav
&&
other
)
=
default
;
membuf
(
T
*
d_
,
bool
rw_
=
false
)
:
d
(
d_
),
rw
(
rw_
)
{}
membuf
(
const
T
*
d_
)
:
d
(
d_
),
rw
(
false
)
{}
membuf
(
size_t
sz
)
:
ptr
(
make_unique
<
vector
<
T
>>
(
sz
)),
d
(
ptr
->
data
()),
rw
(
true
)
{}
membuf
(
const
membuf
&
other
)
:
ptr
(
other
.
ptr
),
d
(
other
.
d
),
rw
(
false
)
{}
membuf
(
membuf
&
other
)
=
default
;
membuf
(
membuf
&&
other
)
=
default
;
// Not for public use!
cfmav
(
const
T
*
d_
,
const
Tsp
&
p
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
)
:
fmav_info
(
shp_
,
str_
),
ptr
(
p
),
d
(
const_cast
<
T
*>
(
d_
))
{}
membuf
(
T
*
d_
,
const
Tsp
&
p
,
bool
rw_
)
:
ptr
(
p
),
d
(
d_
),
rw
(
rw_
)
{}
template
<
typename
I
>
T
&
vraw
(
I
i
)
{
MR_assert
(
rw
,
"array is not writable"
);
return
const_cast
<
T
*>
(
d
)[
i
];
}
template
<
typename
I
>
const
T
&
operator
[](
I
i
)
const
{
return
d
[
i
];
}
const
T
*
data
()
const
{
return
d
;
}
T
*
vdata
()
{
MR_assert
(
rw
,
"array is not writable"
);
return
const_cast
<
T
*>
(
d
);
}
bool
writable
()
{
return
rw
;
}
};
template
<
typename
T
>
class
fmav
:
public
cfmav
<
T
>
// "mav" stands for "multidimensional array view"
template
<
typename
T
>
class
fmav
:
public
fmav_info
,
public
membuf
<
T
>
{
protected:
using
Tsp
=
shared_ptr
<
vector
<
T
>>
;
using
parent
=
cfmav
<
T
>
;
using
parent
::
d
;
using
parent
::
shp
;
using
parent
::
str
;
using
typename
membuf
<
T
>::
Tsp
;
public:
fmav
(
T
*
d_
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
)
:
parent
(
d_
,
shp_
,
str_
)
{}
fmav
(
T
*
d_
,
const
shape_t
&
shp_
)
:
parent
(
d_
,
shp_
)
{}
fmav
(
const
T
*
d_
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
)
:
fmav_info
(
shp_
,
str_
),
membuf
<
T
>
(
d_
)
{}
fmav
(
const
T
*
d_
,
const
shape_t
&
shp_
)
:
fmav_info
(
shp_
),
membuf
<
T
>
(
d_
)
{}
fmav
(
T
*
d_
,
const
shape_t
&
shp_
,
const
stride_t
&
str_
,
bool
rw_
)
:
fmav_info
(
shp_
,
str_
),
membuf
<
T
>
(
d_
,
rw_
)
{}
fmav
(
T
*
d_
,
const
shape_t
&
shp_
,
bool
rw_
)
:
fmav_info
(
shp_
),
membuf
<
T
>
(
d_
,
rw_
)
{}
fmav
(
const
shape_t
&
shp_
)