# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import next, range
import numpy as np
from itertools import product
from functools import reduce
from .domain_object import DomainObject
def get_slice_list(shape, axes):
"""
Helper function which generates slice list(s) to traverse over all
combinations of axes, other than the selected axes.
Parameters
----------
shape: tuple
Shape of the data array to traverse over.
axes: tuple
Axes which should not be iterated over.
Yields
-------
list
The next list of indices and/or slice objects for each dimension.
Raises
------
ValueError
If shape is empty.
ValueError
If axes(axis) does not match shape.
"""
if not shape:
raise ValueError("shape cannot be None.")
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError("axes(axis) does not match shape.")
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables = \
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
yield slice_list
else:
yield [slice(None, None)]
def cast_iseq_to_tuple(seq):
if seq is None:
return None
if np.isscalar(seq):
return (int(seq),)
return tuple(int(item) for item in seq)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
else:
length = np.amax(obj) + 1
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj
# if present, parse the axis keyword and transpose/reorder self.data
# such that all affected axes follow each other. Only if they are in a
# sequence flattening will be possible
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
data = np.transpose(data, reordering)
if weights is not None:
weights = np.transpose(weights, reordering)
reord_axis = list(range(ndim-len(axis), ndim))
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
semi_flat_dim = reduce(lambda x, y: x*y,
data.shape[ndim-len(reord_axis):])
flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
else:
flat_shape = (reduce(lambda x, y: x*y, data.shape), )
data = np.ascontiguousarray(data.reshape(flat_shape))
if weights is not None:
weights = np.ascontiguousarray(weights.reshape(flat_shape))
# compute the local bincount results
# -> prepare the local result array
result_dtype = np.int if weights is None else np.float
local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
current_weights = None if weights is None else weights[slice_list]
local_counts[slice_list] = np.bincount(data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
# place the bincount stuff at the location of the first `axis` entry
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (list(range(0, insert_position)) +
[new_ndim-1, ] +
list(range(insert_position, new_ndim-1)))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return local_counts