Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Reinecke
pypocketfft
Commits
28540189
Commit
28540189
authored
Jun 18, 2019
by
Martin Reinecke
Browse files
introduce simple_iter and rev_iter to simplify iterating over arrays with Hermitian symmetry
parent
d77cdeb5
Changes
2
Hide whitespace changes
Inline
Side-by-side
pocketfft_hdronly.h
View file @
28540189
...
...
@@ -2236,16 +2236,18 @@ template<typename T> class ndarr
{
return
*
reinterpret_cast
<
T
*>
(
d
+
ofs
);
}
const
T
&
operator
[](
ptrdiff_t
ofs
)
const
{
return
*
reinterpret_cast
<
const
T
*>
(
cd
+
ofs
);
}
T
get
(
ptrdiff_t
ofs
)
const
{
return
*
reinterpret_cast
<
const
T
*>
(
cd
+
ofs
);
}
void
set
(
ptrdiff_t
ofs
,
const
T
&
val
)
const
{
*
reinterpret_cast
<
T
*>
(
d
+
ofs
)
=
val
;
}
};
template
<
size_t
N
,
typename
Ti
,
typename
To
>
class
multi_iter
{
p
ublic
:
p
rivate
:
shape_t
pos
;
const
ndarr
<
Ti
>
&
iarr
;
ndarr
<
To
>
&
oarr
;
private:
ptrdiff_t
p_ii
,
p_i
[
N
],
str_i
,
p_oi
,
p_o
[
N
],
str_o
;
size_t
idim
,
rem
;
...
...
@@ -2320,6 +2322,97 @@ template<size_t N, typename Ti, typename To> class multi_iter
bool
contiguous_out
()
const
{
return
str_o
==
sizeof
(
To
);
}
};
template
<
typename
T
>
class
simple_iter
{
private:
shape_t
pos
;
const
ndarr
<
T
>
&
arr
;
ptrdiff_t
p
;
size_t
rem
;
public:
simple_iter
(
const
ndarr
<
T
>
&
arr_
)
:
pos
(
arr_
.
ndim
(),
0
),
arr
(
arr_
),
p
(
0
),
rem
(
arr_
.
size
())
{}
void
advance
()
{
--
rem
;
for
(
int
i_
=
int
(
pos
.
size
())
-
1
;
i_
>=
0
;
--
i_
)
{
auto
i
=
size_t
(
i_
);
p
+=
arr
.
stride
(
i
);
if
(
++
pos
[
i
]
<
arr
.
shape
(
i
))
return
;
pos
[
i
]
=
0
;
p
-=
ptrdiff_t
(
arr
.
shape
(
i
))
*
arr
.
stride
(
i
);
}
}
ptrdiff_t
ofs
()
const
{
return
p
;
}
size_t
remaining
()
const
{
return
rem
;
}
};
template
<
typename
T
>
class
rev_iter
{
private:
shape_t
pos
;
ndarr
<
T
>
&
arr
;
vector
<
char
>
rev_axis
;
vector
<
char
>
rev_jump
;
size_t
last_axis
,
last_size
;
shape_t
shp
;
ptrdiff_t
p
,
rp
;
size_t
rem
;
public:
rev_iter
(
ndarr
<
T
>
&
arr_
,
const
shape_t
&
axes
)
:
pos
(
arr_
.
ndim
(),
0
),
arr
(
arr_
),
rev_axis
(
arr_
.
ndim
(),
0
),
rev_jump
(
arr_
.
ndim
(),
1
),
p
(
0
),
rp
(
0
)
{
for
(
auto
ax
:
axes
)
rev_axis
[
ax
]
=
1
;
last_axis
=
axes
.
back
();
last_size
=
arr
.
shape
(
last_axis
)
/
2
+
1
;
shp
=
arr
.
shape
();
shp
[
last_axis
]
=
last_size
;
rem
=
1
;
for
(
auto
i
:
shp
)
rem
*=
i
;
}
void
advance
()
{
--
rem
;
for
(
int
i_
=
int
(
pos
.
size
())
-
1
;
i_
>=
0
;
--
i_
)
{
auto
i
=
size_t
(
i_
);
p
+=
arr
.
stride
(
i
);
if
(
!
rev_axis
[
i
])
rp
+=
arr
.
stride
(
i
);
else
{
rp
-=
arr
.
stride
(
i
);
if
(
rev_jump
[
i
])
{
rp
+=
ptrdiff_t
(
arr
.
shape
(
i
))
*
arr
.
stride
(
i
);
rev_jump
[
i
]
=
0
;
}
}
if
(
++
pos
[
i
]
<
shp
[
i
])
return
;
pos
[
i
]
=
0
;
p
-=
ptrdiff_t
(
shp
[
i
])
*
arr
.
stride
(
i
);
if
(
rev_axis
[
i
])
{
rp
-=
ptrdiff_t
(
arr
.
shape
(
i
)
-
shp
[
i
])
*
arr
.
stride
(
i
);
rev_jump
[
i
]
=
1
;
}
else
rp
-=
ptrdiff_t
(
shp
[
i
])
*
arr
.
stride
(
i
);
}
}
ptrdiff_t
ofs
()
const
{
return
p
;
}
ptrdiff_t
rev_ofs
()
const
{
return
rp
;
}
size_t
remaining
()
const
{
return
rem
;
}
};
#ifndef POCKETFFT_NO_VECTORS
template
<
typename
T
>
struct
VTYPE
{};
template
<
>
struct
VTYPE
<
float
>
...
...
pypocketfft.cc
View file @
28540189
...
...
@@ -216,7 +216,6 @@ template<typename T>py::array complex2hartley(const py::array &in,
const
py
::
array
&
tmp
,
py
::
object
axes_
,
bool
inplace
)
{
using
namespace
pocketfft
::
detail
;
size_t
ndim
=
size_t
(
in
.
ndim
());
auto
dims_out
(
copy_shape
(
in
));
py
::
array
out
=
inplace
?
in
:
py
::
array_t
<
T
>
(
dims_out
);
ndarr
<
cmplx
<
T
>>
atmp
(
tmp
.
data
(),
copy_shape
(
tmp
),
copy_strides
(
tmp
));
...
...
@@ -224,35 +223,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto
axes
=
makeaxes
(
in
,
axes_
);
{
py
::
gil_scoped_release
release
;
size_t
axis
=
axes
.
back
();
multi_iter
<
1
,
cmplx
<
T
>
,
T
>
it
(
atmp
,
aout
,
axis
);
vector
<
bool
>
swp
(
ndim
,
false
);
for
(
auto
i
:
axes
)
if
(
i
!=
axis
)
swp
[
i
]
=
true
;
while
(
it
.
remaining
()
>
0
)
simple_iter
<
cmplx
<
T
>>
iin
(
atmp
);
rev_iter
<
T
>
iout
(
aout
,
axes
);
if
(
iin
.
remaining
()
!=
iout
.
remaining
())
throw
runtime_error
(
"oops"
);
while
(
iin
.
remaining
()
>
0
)
{
ptrdiff_t
rofs
=
0
;
for
(
size_t
i
=
0
;
i
<
it
.
pos
.
size
();
++
i
)
{
if
(
i
==
axis
)
continue
;
if
(
!
swp
[
i
])
rofs
+=
ptrdiff_t
(
it
.
pos
[
i
])
*
it
.
oarr
.
stride
(
i
);
else
{
auto
x
=
ptrdiff_t
((
it
.
pos
[
i
]
==
0
)
?
0
:
it
.
iarr
.
shape
(
i
)
-
it
.
pos
[
i
]);
rofs
+=
x
*
it
.
oarr
.
stride
(
i
);
}
}
it
.
advance
(
1
);
for
(
size_t
i
=
0
;
i
<
it
.
length_in
();
++
i
)
{
auto
re
=
it
.
in
(
i
).
r
;
auto
im
=
it
.
in
(
i
).
i
;
auto
rev_i
=
ptrdiff_t
((
i
==
0
)
?
0
:
it
.
length_out
()
-
i
);
it
.
out
(
i
)
=
re
+
im
;
aout
[
rofs
+
rev_i
*
it
.
stride_out
()]
=
re
-
im
;
}
auto
v
=
atmp
.
get
(
iin
.
ofs
());
aout
.
set
(
iout
.
ofs
(),
v
.
r
+
v
.
i
);
aout
.
set
(
iout
.
rev_ofs
(),
v
.
r
-
v
.
i
);
iin
.
advance
();
iout
.
advance
();
}
}
return
out
;
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment