Commit 7f09a39a authored by theos
Browse files

obj.unique() now is built on top of MPI.comm.allreduce with a custom MPI.Op

parent ec62ebd2
Pipeline #1744 skipped
......@@ -1377,10 +1377,7 @@ class distributed_data_object(object):
def unique(self):
""" Returns a `numpy.ndarray` holding the d2o's unique elements. """
local_unique = np.unique(self.get_local_data())
global_unique = self.distributor._allgather(local_unique)
global_unique = np.concatenate(global_unique)
return np.unique(global_unique)
return self.distributor.unique(self.data)
def bincount(self, weights=None, minlength=None):
""" Count weighted number of occurrences of each value in the d2o.
......
......@@ -1355,6 +1355,17 @@ class _slicing_distributor(distributor):
local_where)
return global_where
def unique(self, data):
    """ Returns the sorted unique elements of the data held by all nodes.

    Parameters
    ----------
    data : array_like
        The node's local data. Its shape is irrelevant, as `np.unique`
        works on the flattened input.

    Returns
    -------
    unique_data : numpy.ndarray
        One-dimensional array of the unique elements, combined over all
        processes of `self.comm`.
    """
    # De-duplicate locally before communicating: the global result is the
    # same, but the objects exchanged between the processes stay small.
    local_unique = np.unique(data)

    # If the size of the MPI communicator is equal to 1, the reduce
    # operator will not be applied -> the local result is the global one.
    if self.comm.size == 1:
        return local_unique

    # Only the operator itself is needed here; the flag stored next to it
    # in the tuple is not used by this method.
    mpi_unique = op_translate_dict[np.unique][0]
    return self.comm.allreduce(local_unique, op=mpi_unique)
def bincount(self, local_data, local_weights, minlength):
local_counts = np.bincount(local_data,
weights=local_weights,
......@@ -1851,6 +1862,9 @@ class _not_distributor(distributor):
local_where)
return global_where
def unique(self, data):
    """ Returns the unique elements of `data`, as computed by `np.unique`. """
    unique_data = np.unique(data)
    return unique_data
def bincount(self, local_data, local_weights, minlength):
counts = np.bincount(local_data,
weights=local_weights,
......
......@@ -13,6 +13,9 @@ custom_NANMIN = MPI.Op.Create(lambda x, y, datatype:
# Reduction operator for an element-wise maximum that ignores NaNs: the two
# operands are stacked as rows and `np.nanmax` is taken along that axis.
# `np.vstack` accepts a single sequence of arrays, hence the tuple (x, y).
custom_NANMAX = MPI.Op.Create(lambda x, y, datatype:
                              np.nanmax(np.vstack((x, y)), axis=0))
def _unique_of_both(x, y, datatype):
    """ Returns the unique elements found in the concatenation of x and y. """
    joined = np.concatenate([x, y])
    return np.unique(joined)


# Reduction operator that merges two arrays into their common unique elements.
custom_UNIQUE = MPI.Op.Create(_unique_of_both)
op_translate_dict = {}
# the value tuple contains the operator and a boolean which specifies
......@@ -25,3 +28,4 @@ op_translate_dict[np.all] = (MPI.LAND, True)
# Each entry maps a numpy function to a tuple (MPI operator, boolean flag).
# NOTE(review): presumably the flag states whether the operator may be used
# with the buffer-based MPI routines -- confirm against the callers.
op_translate_dict[np.any] = (MPI.LOR, True)
op_translate_dict[np.nanmin] = (custom_NANMIN, False)
op_translate_dict[np.nanmax] = (custom_NANMAX, False)
op_translate_dict[np.unique] = (custom_UNIQUE, False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment