Commit 857bb932 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

remove duplicate code

parent a7f3e8c4
...@@ -16,11 +16,9 @@ ...@@ -16,11 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import next from builtins import next, range
from builtins import range
import numpy as np import numpy as np
from itertools import product from itertools import product
import itertools
from functools import reduce from functools import reduce
...@@ -113,50 +111,6 @@ def parse_domain(domain): ...@@ -113,50 +111,6 @@ def parse_domain(domain):
return domain return domain
def slicing_generator(shape, axes):
"""
Helper function which generates slice list(s) to traverse over all
combinations of axes, other than the selected axes.
Parameters
----------
shape: tuple
Shape of the data array to traverse over.
axes: tuple
Axes which should not be iterated over.
Yields
-------
list
The next list of indices and/or slice objects for each dimension.
Raises
------
ValueError
If shape is empty.
ValueError
If axes(axis) does not match shape.
"""
if not shape:
raise ValueError("ERROR: shape cannot be None.")
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError("ERROR: axes(axis) does not match shape.")
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for current_index in itertools.product(*axes_iterables):
it_iter = iter(current_index)
slice_list = [next(it_iter) if use_axis else
slice(None, None) for use_axis in axes_select]
yield slice_list
else:
yield [slice(None, None)]
return
def bincount_axis(obj, minlength=None, weights=None, axis=None): def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None: if minlength is not None:
length = max(np.amax(obj) + 1, minlength) length = max(np.amax(obj) + 1, minlength)
...@@ -206,8 +160,8 @@ def bincount_axis(obj, minlength=None, weights=None, axis=None): ...@@ -206,8 +160,8 @@ def bincount_axis(obj, minlength=None, weights=None, axis=None):
dtype=result_dtype) dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local # iterate over all entries in the surviving axes and compute the local
# bincounts # bincounts
for slice_list in slicing_generator(flat_shape, for slice_list in get_slice_list(flat_shape,
axes=(len(flat_shape)-1, )): axes=(len(flat_shape)-1, )):
if weights is not None: if weights is not None:
current_weights = weights[slice_list] current_weights = weights[slice_list]
else: else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment