Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
857bb932
Commit
857bb932
authored
Sep 07, 2017
by
Martin Reinecke
Browse files
remove duplicate code
parent
a7f3e8c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty2go/nifty_utilities.py
View file @
857bb932
...
...
@@ -16,11 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
builtins
import
next
from
builtins
import
range
from
builtins
import
next
,
range
import
numpy
as
np
from
itertools
import
product
import
itertools
from
functools
import
reduce
...
...
@@ -113,50 +111,6 @@ def parse_domain(domain):
return
domain
def
slicing_generator
(
shape
,
axes
):
"""
Helper function which generates slice list(s) to traverse over all
combinations of axes, other than the selected axes.
Parameters
----------
shape: tuple
Shape of the data array to traverse over.
axes: tuple
Axes which should not be iterated over.
Yields
-------
list
The next list of indices and/or slice objects for each dimension.
Raises
------
ValueError
If shape is empty.
ValueError
If axes(axis) does not match shape.
"""
if
not
shape
:
raise
ValueError
(
"ERROR: shape cannot be None."
)
if
axes
:
if
not
all
(
axis
<
len
(
shape
)
for
axis
in
axes
):
raise
ValueError
(
"ERROR: axes(axis) does not match shape."
)
axes_select
=
[
0
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
shape
)]
axes_iterables
=
\
[
list
(
range
(
y
))
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
for
current_index
in
itertools
.
product
(
*
axes_iterables
):
it_iter
=
iter
(
current_index
)
slice_list
=
[
next
(
it_iter
)
if
use_axis
else
slice
(
None
,
None
)
for
use_axis
in
axes_select
]
yield
slice_list
else
:
yield
[
slice
(
None
,
None
)]
return
def
bincount_axis
(
obj
,
minlength
=
None
,
weights
=
None
,
axis
=
None
):
if
minlength
is
not
None
:
length
=
max
(
np
.
amax
(
obj
)
+
1
,
minlength
)
...
...
@@ -206,8 +160,8 @@ def bincount_axis(obj, minlength=None, weights=None, axis=None):
dtype
=
result_dtype
)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for
slice_list
in
slicing_generator
(
flat_shape
,
axes
=
(
len
(
flat_shape
)
-
1
,
)):
for
slice_list
in
get_slice_list
(
flat_shape
,
axes
=
(
len
(
flat_shape
)
-
1
,
)):
if
weights
is
not
None
:
current_weights
=
weights
[
slice_list
]
else
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment