Commit 2621861d authored by theos

Added axis-keyword functionality to bincount.

parent 72493960
......@@ -1431,7 +1431,7 @@ class distributed_data_object(object):
return self.distributor.unique(self.data)
def bincount(self, weights=None, minlength=None):
def bincount(self, weights=None, minlength=None, axis=None):
""" Count weighted number of occurrences of each value in the d2o.
The number of integer bins is `max(self.amax()+1, minlength)`.
......@@ -1465,19 +1465,15 @@ class distributed_data_object(object):
raise TypeError(about_cstring(
"ERROR: Distributed-data-object must be of integer datatype!"))
minlength = max(self.amax() + 1, minlength)
if axis is ():
return self.copy()
if weights is not None:
local_weights = self.distributor.extract_local_data(weights).\
flatten()
else:
local_weights = None
length = max(self.amax() + 1, minlength)
local_data = self.get_local_data(copy=False).flatten()
counts = self.distributor.bincount(local_data=local_data,
local_weights=local_weights,
minlength=minlength)
return counts
return self.distributor.bincount(obj=self,
length=length,
weights=weights,
axis=axis)
def where(self):
""" Return the indices where `self` is True.
......
......@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
from d2o.config import configuration as gc,\
......@@ -29,6 +30,7 @@ from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter
from cast_axis_to_tuple import cast_axis_to_tuple
from translate_to_mpi_operator import op_translate_dict
from slicing_generator import slicing_generator
from strategies import STRATEGIES
......@@ -42,6 +44,7 @@ about_cstring = lambda z: z
from sys import stdout
about_infos_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
class _distributor_factory(object):
def __init__(self):
......@@ -402,6 +405,73 @@ class distributor(object):
**kwargs)
i += 1
def bincount(self, obj, length, weights=None, axis=None):
    """ Count the (weighted) occurrences of the integer values in `obj`.

    The node-local data of `obj` is bincounted with `numpy.bincount`,
    either over the flattened array (`axis=None`) or separately along the
    axes named in `axis`. The local result is then handed to
    `self._combine_local_bincount_counts`, which builds the return value.

    This implementation fits all distribution strategies where the axes
    of the global array correspond to the axes of the local data.

    Parameters
    ----------
    obj : object
        Provides the local data via `obj.get_local_data(copy=False)`.
    length : int
        Number of bins. Used as `minlength` for every `numpy.bincount`
        call and as the extent of the bin axis of the local result.
    weights : optional
        If given, its local part is fetched with
        `self.extract_local_data(weights)` and used as bincount weights.
    axis : optional
        Axis or axes to count along, parsed by `cast_axis_to_tuple`.
        If None, the flattened local data is counted.

    Returns
    -------
    The result of `self._combine_local_bincount_counts(obj, local_counts,
    axis)`, where `axis` is None or the sorted list of affected axes.
    """
    data = obj.get_local_data(copy=False)

    if weights is not None:
        local_weights = self.extract_local_data(weights)
    else:
        local_weights = None

    # If present, parse the axis keyword and transpose the data such that
    # all affected axes follow each other at the end. Only a contiguous
    # run of axes can be collapsed into one by a reshape.
    if axis is not None:
        ndim = len(self.global_shape)
        axis = sorted(cast_axis_to_tuple(axis, length=ndim))
        reordering = [x for x in range(ndim) if x not in axis]
        reordering += axis

        data = np.transpose(data, reordering)
        if local_weights is not None:
            local_weights = np.transpose(local_weights, reordering)

        # semi-flatten: keep the untouched leading axes and merge the
        # trailing (affected) ones into a single axis
        n_kept = ndim - len(axis)
        semi_flat_dim = int(np.prod(data.shape[n_kept:]))
        flat_shape = data.shape[:n_kept] + (semi_flat_dim, )
    else:
        flat_shape = (data.size, )

    data = np.ascontiguousarray(data.reshape(flat_shape))
    if local_weights is not None:
        local_weights = np.ascontiguousarray(
                                    local_weights.reshape(flat_shape))

    # Prepare the local result array. The builtin types are used because
    # the `np.int`/`np.float` aliases no longer exist in current NumPy.
    if local_weights is None:
        result_dtype = int
    else:
        result_dtype = float
    local_counts = np.empty(flat_shape[:-1] + (length, ),
                            dtype=result_dtype)

    # Iterate over all entries in the surviving axes and compute the
    # local bincounts one 1-d row at a time.
    for slice_list in slicing_generator(flat_shape,
                                        axes=(len(flat_shape)-1, )):
        # NumPy needs a tuple for multi-dimensional indexing
        index = tuple(slice_list)
        # np.bincount requires 1-d weights matching the current data row
        if local_weights is None:
            current_weights = None
        else:
            current_weights = local_weights[index]
        local_counts[index] = np.bincount(data[index],
                                          weights=current_weights,
                                          minlength=length)

    # Restore the original ordering: move the bin axis (currently last)
    # to the position of the first entry of `axis`.
    if axis is not None:
        # axis has been sorted above
        insert_position = axis[0]
        # The result has fewer axes than the input if several axes were
        # merged, hence use the result's own dimensionality here.
        last = local_counts.ndim - 1
        return_order = (list(range(insert_position)) +
                        [last, ] +
                        list(range(insert_position, last)))
        local_counts = np.ascontiguousarray(
                                    local_counts.transpose(return_order))

    return self._combine_local_bincount_counts(obj, local_counts, axis)
class _slicing_distributor(distributor):
def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
......@@ -1495,23 +1565,40 @@ class _slicing_distributor(distributor):
global_unique_data]))
return global_unique_data
def bincount(self, local_data, local_weights, minlength):
if local_weights is None:
result_dtype = np.int
def _combine_local_bincount_counts(self, obj, local_counts, axis):
if axis is None or 0 in axis:
global_counts = np.empty_like(local_counts)
self._Allreduce_helper(local_counts, global_counts, MPI.SUM)
result_object = obj.copy_empty(global_shape=global_counts.shape,
dtype=global_counts.dtype,
distribution_strategy='not')
else:
result_dtype = np.float
global_counts = local_counts
result_object = obj.copy_empty(local_shape=global_counts.shape,
dtype=global_counts.dtype,
distribution_strategy='freeform')
local_counts = np.bincount(local_data,
weights=local_weights,
minlength=minlength)
result_object.set_local_data(global_counts, copy=False)
return result_object
# cast the local_counts to the right dtype while avoiding copying
local_counts = np.array(local_counts, copy=False, dtype=result_dtype)
global_counts = np.empty_like(local_counts)
self._Allreduce_helper(local_counts,
global_counts,
MPI.SUM)
return global_counts
# def bincount(self, local_data, local_weights, minlength, axis=None):
# if local_weights is None:
# result_dtype = np.int
# else:
# result_dtype = np.float
#
# local_counts = np.bincount(local_data,
# weights=local_weights,
# minlength=minlength)
#
# # cast the local_counts to the right dtype while avoiding copying
# local_counts = np.array(local_counts, copy=False, dtype=result_dtype)
# global_counts = np.empty_like(local_counts)
# self._Allreduce_helper(local_counts,
# global_counts,
# MPI.SUM)
# return global_counts
def cumsum(self, parent, axis):
data = parent.data
......@@ -2053,11 +2140,11 @@ class _not_distributor(distributor):
def unique(self, data):
return np.unique(data)
def bincount(self, local_data, local_weights, minlength):
    """ Return `numpy.bincount` of `local_data`.

    Parameters
    ----------
    local_data : array-like
        Passed to `numpy.bincount` as the values to count.
    local_weights : array-like or None
        Passed to `numpy.bincount` as `weights`.
    minlength : int
        Passed to `numpy.bincount` as `minlength`.

    Returns
    -------
    numpy.ndarray
        The unmodified result of `numpy.bincount`.
    """
    # no post-processing is applied to the numpy result
    return np.bincount(local_data,
                       weights=local_weights,
                       minlength=minlength)
def _combine_local_bincount_counts(self, obj, local_counts, axis):
    """ Wrap the local bincount result into a new object derived from `obj`.

    No reduction over processes is performed here; `local_counts` is used
    as it is.

    Parameters
    ----------
    obj : object
        Template object; `obj.copy_empty` creates the result container.
    local_counts : numpy.ndarray
        The bincount result; its shape is used as the global shape and
        its dtype as the dtype of the result.
    axis : optional
        Accepted as part of the signature, but not used by this method.

    Returns
    -------
    The object created by `obj.copy_empty`, holding `local_counts` as its
    local data (set without copying).
    """
    counts_object = obj.copy_empty(global_shape=local_counts.shape,
                                   dtype=local_counts.dtype)
    # hand over the array itself instead of a copy
    counts_object.set_local_data(local_counts, copy=False)
    return counts_object
def cumsum(self, parent, axis):
data = parent.data
......
# D2O
# Copyright (C) 2016 Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import itertools
def slicing_generator(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    For every combination of indices along the axes that are *not* listed
    in `axes`, one list is yielded. It holds an integer index for each
    iterated axis and `slice(None, None)` for each axis in `axes`.
    If `axes` is empty, the single list `[slice(None, None)]` is yielded.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative values count
        from the last axis, i.e. `-1` refers to `len(shape) - 1`.

    Yields
    -------
    list
        The next list of indices and/or slice objects for each dimension.
        Convert it with `tuple(...)` before indexing a numpy array.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("ERROR: shape cannot be None.")

    if not axes:
        yield [slice(None, None)]
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("ERROR: axes(axis) does not match shape.")

    # Map negative axes onto their non-negative counterparts. Without
    # this they would pass the bounds check but never match a position
    # below and hence be iterated over instead of being kept whole.
    kept_axes = set(axis % ndim for axis in axes)

    axes_iterables = [range(extent)
                      for position, extent in enumerate(shape)
                      if position not in kept_axes]

    for current_index in itertools.product(*axes_iterables):
        # consume one index per iterated axis, in axis order
        it_iter = iter(current_index)
        yield [slice(None, None) if position in kept_axes
               else next(it_iter)
               for position in range(ndim)]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment