Commit 0bba7a71 authored by theos
Browse files

Added a _selective_allreduce method to the slicing distributor in order to fix...

Added a _selective_allreduce method to the slicing distributor in order to fix contractions that involve nodes with empty data.
parent af793f09
Pipeline #2053 skipped
...@@ -1057,10 +1057,12 @@ class distributed_data_object(object): ...@@ -1057,10 +1057,12 @@ class distributed_data_object(object):
-------- --------
numpy.amin numpy.amin
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.amin, self,
axis=axis, np.amin,
**kwargs) allow_empty_contractions=False,
axis=axis,
**kwargs)
def nanmin(self, axis=None, **kwargs): def nanmin(self, axis=None, **kwargs):
""" Returns the minimum of an array ignoring all NaNs. """ Returns the minimum of an array ignoring all NaNs.
...@@ -1069,10 +1071,12 @@ class distributed_data_object(object): ...@@ -1069,10 +1071,12 @@ class distributed_data_object(object):
-------- --------
numpy.nanmin numpy.nanmin
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.nanmin, self,
axis=axis, np.nanmin,
**kwargs) allow_empty_contractions=False,
axis=axis,
**kwargs)
def max(self, axis=None, **kwargs): def max(self, axis=None, **kwargs):
""" x.max() <==> x.amax() """ """ x.max() <==> x.amax() """
...@@ -1085,10 +1089,12 @@ class distributed_data_object(object): ...@@ -1085,10 +1089,12 @@ class distributed_data_object(object):
-------- --------
numpy.amax numpy.amax
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.amax, self,
axis=axis, np.amax,
**kwargs) allow_empty_contractions=False,
axis=axis,
**kwargs)
def nanmax(self, axis=None, **kwargs): def nanmax(self, axis=None, **kwargs):
""" Returns the maximum of an array ignoring all NaNs. """ Returns the maximum of an array ignoring all NaNs.
...@@ -1097,10 +1103,12 @@ class distributed_data_object(object): ...@@ -1097,10 +1103,12 @@ class distributed_data_object(object):
-------- --------
numpy.nanmax numpy.nanmax
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.nanmax, self,
axis=axis, np.nanmax,
**kwargs) allow_empty_contractions=False,
axis=axis,
**kwargs)
def sum(self, axis=None, **kwargs): def sum(self, axis=None, **kwargs):
""" Sums the array elements. """ Sums the array elements.
...@@ -1109,10 +1117,12 @@ class distributed_data_object(object): ...@@ -1109,10 +1117,12 @@ class distributed_data_object(object):
-------- --------
numpy.sum numpy.sum
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.sum, self,
axis=axis, np.sum,
**kwargs) allow_empty_contractions=True,
axis=axis,
**kwargs)
def prod(self, axis=None, **kwargs): def prod(self, axis=None, **kwargs):
""" Multiplies the array elements. """ Multiplies the array elements.
...@@ -1121,22 +1131,28 @@ class distributed_data_object(object): ...@@ -1121,22 +1131,28 @@ class distributed_data_object(object):
-------- --------
numpy.prod numpy.prod
""" """
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.prod, self,
axis=axis, np.prod,
**kwargs) allow_empty_contractions=True,
axis=axis,
**kwargs)
def all(self, axis=None, **kwargs): def all(self, axis=None, **kwargs):
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.all, self,
axis=axis, np.all,
**kwargs) allow_empty_contractions=True,
axis=axis,
**kwargs)
def any(self, axis=None, **kwargs): def any(self, axis=None, **kwargs):
return self.distributor._contraction_helper(self, return self.distributor.contraction_helper(
np.any, self,
axis=axis, np.any,
**kwargs) allow_empty_contractions=True,
axis=axis,
**kwargs)
def mean(self, axis=None, **kwargs): def mean(self, axis=None, **kwargs):
# infer, which axes will be collapsed # infer, which axes will be collapsed
......
...@@ -409,6 +409,7 @@ class _slicing_distributor(distributor): ...@@ -409,6 +409,7 @@ class _slicing_distributor(distributor):
self.local_start = self._local_size[0] self.local_start = self._local_size[0]
self.local_end = self._local_size[1] self.local_end = self._local_size[1]
self.global_shape = self._local_size[2] self.global_shape = self._local_size[2]
self.global_dim = reduce(lambda x, y: x*y, self.global_shape)
self.local_length = self.local_end - self.local_start self.local_length = self.local_end - self.local_start
self.local_shape = (self.local_length,) + tuple(self.global_shape[1:]) self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
...@@ -497,7 +498,70 @@ class _slicing_distributor(distributor): ...@@ -497,7 +498,70 @@ class _slicing_distributor(distributor):
op=op) op=op)
return recvbuf return recvbuf
def _selective_allreduce(self, data, op, bufferQ=False):
    """Combine `data` across all processes, skipping those that hold None.

    Every process reports what kind of contribution it has. The
    contributions of all processes whose `data` is not None are then sent
    to everyone, one contributing rank after the other in ascending rank
    order, and folded together with `op`.

    Parameters
    ----------
    data : object or None
        The local contribution. None means this process does not take
        part in the reduction (e.g. because it holds no local data).
    op : callable
        Binary operation, invoked as ``op(accumulated, next_contribution)``.
        NOTE(review): the visible caller passes the MPI operation looked up
        in `op_translate_dict`; confirm it is callable on the data that is
        exchanged here.
    bufferQ : bool
        Whether buffer-based (uppercase) MPI communication may be used.
        It is only used if, additionally, every contribution is a
        non-complex numpy array.

    Returns
    -------
    result_data : object
        The combined contributions; identical content on every process.
        With a single process `data` itself is returned.

    Raises
    ------
    ValueError
        If no process holds non-None data.
    """
    size = self.comm.size
    rank = self.comm.rank
    if size == 1:
        # nothing to combine with
        return data

    # Classify the local contribution:
    #   0 -> None, does not take part
    #   1 -> not a numpy array
    #   2 -> complex array (MPI.MAX and MPI.MIN do not support complex
    #        data types)
    #   3 -> non-complex array
    if data is None:
        kind = 0
    elif not isinstance(data, np.ndarray):
        kind = 1
    elif np.issubdtype(data.dtype, np.complexfloating):
        kind = 2
    else:
        kind = 3

    # np.intc is the C `int`, i.e. the element type MPI.INT describes.
    got_array = np.array([kind], dtype=np.intc)
    got_array_list = np.empty(size, dtype=np.intc)
    self.comm.Allgather([got_array, MPI.INT],
                        [got_array_list, MPI.INT])

    # ranks with non-None data, in ascending order
    contributors = [i for i in range(size) if got_array_list[i] > 0]
    if not contributors:
        raise ValueError("ERROR: No process with non-None data.")
    start = contributors[0]

    # The decision must be taken on the *gathered* list: it has to be the
    # same on every process, and the local one-element `got_array` cannot
    # be indexed by a rank. Requiring kind 3 from every contributor keeps
    # a non-array or complex contribution off the raw-buffer path.
    use_buffers = bufferQ and all(got_array_list[i] == 3
                                  for i in contributors)

    if use_buffers:
        # Only the start process is guaranteed to hold an array, so only
        # it may look at data.dtype / data.shape; the others pass None.
        if rank == start:
            meta = (data.dtype, data.shape)
        else:
            meta = None
        (new_dtype, new_shape) = self.comm.bcast(meta, root=start)
        mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)

        result_data = None
        for position, root in enumerate(contributors):
            if rank == root:
                temp_data = data
            else:
                # assumes all contributions share the start process'
                # shape and dtype -- TODO confirm
                temp_data = np.empty(new_shape, dtype=new_dtype)
            self.comm.Bcast([temp_data, mpi_dtype], root=root)
            if position == 0:
                result_data = temp_data
            else:
                result_data = op(result_data, temp_data)
    else:
        # generic path: pickle-based communication of arbitrary objects
        result_data = self.comm.bcast(data, root=start)
        for root in contributors[1:]:
            temp_data = self.comm.bcast(data, root=root)
            result_data = op(result_data, temp_data)

    return result_data
def contraction_helper(self, parent, function, allow_empty_contractions,
axis=None, **kwargs):
if axis == (): if axis == ():
return parent.copy() return parent.copy()
...@@ -509,40 +573,40 @@ class _slicing_distributor(distributor): ...@@ -509,40 +573,40 @@ class _slicing_distributor(distributor):
new_shape = tuple([old_shape[i] for i in xrange(len(old_shape)) new_shape = tuple([old_shape[i] for i in xrange(len(old_shape))
if i not in axis]) if i not in axis])
# do the contraction on the node's local data
local_data = parent.data local_data = parent.data
contracted_local_data = function(local_data, axis=axis, **kwargs)
new_dtype = contracted_local_data.dtype # if all local data is empty and empty_contractions are forbidden
# call function on the local_data in order to raise the right exception
if self.global_dim == 0 and not allow_empty_contractions:
# this shall raise an exception
function(local_data, axis=axis, **kwargs)
# do the contraction on the node's local data
if self.local_dim == 0 and not allow_empty_contractions:
# this case will only be reached if some nodes have data and some
# not
contracted_local_data = None
else:
# if local_dim == 0 but empty contractions will be allowed
# this will be a `contraction neutral` array.
contracted_local_data = function(local_data, axis=axis, **kwargs)
# check if additional contraction along the first axis must be done # check if additional contraction along the first axis must be done
if axis is None or 0 in axis: if axis is None or 0 in axis:
(mpi_op, bufferQ) = op_translate_dict[function] (mpi_op, bufferQ) = op_translate_dict[function]
# check if allreduce must be used instead of Allreduce contracted_global_data = self._selective_allreduce(
use_Uppercase = False contracted_local_data,
if bufferQ and isinstance(contracted_local_data, np.ndarray): mpi_op,
# MPI.MAX and MPI.MIN do not support complex data types bufferQ)
if not np.issubdtype(contracted_local_data.dtype,
np.complexfloating):
use_Uppercase = True
if use_Uppercase:
global_contracted_local_data = np.empty_like(
contracted_local_data)
new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
self.comm.Allreduce([contracted_local_data,
new_mpi_dtype],
[global_contracted_local_data,
new_mpi_dtype],
op=mpi_op)
else:
global_contracted_local_data = self.comm.allreduce(
contracted_local_data, op=mpi_op)
new_dist_strategy = 'not' new_dist_strategy = 'not'
else: else:
contracted_global_data = contracted_local_data
new_dist_strategy = parent.distribution_strategy new_dist_strategy = parent.distribution_strategy
global_contracted_local_data = contracted_local_data
new_dtype = contracted_global_data.dtype
if new_shape == (): if new_shape == ():
result = global_contracted_local_data result = contracted_global_data
else: else:
# try to store the result in a distributed_data_object with the # try to store the result in a distributed_data_object with the
# distribution_strategy as parent # distribution_strategy as parent
...@@ -556,12 +620,12 @@ class _slicing_distributor(distributor): ...@@ -556,12 +620,12 @@ class _slicing_distributor(distributor):
# Contracting (4, 4) to (4,). # Contracting (4, 4) to (4,).
# (4, 4) was distributed (1, 4)...(1, 4) # (4, 4) was distributed (1, 4)...(1, 4)
# (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()! # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
if result.local_shape != global_contracted_local_data.shape: if result.local_shape != contracted_global_data.shape:
result = parent.copy_empty( result = parent.copy_empty(
local_shape=global_contracted_local_data.shape, local_shape=contracted_global_data.shape,
dtype=new_dtype, dtype=new_dtype,
distribution_strategy='freeform') distribution_strategy='freeform')
result.set_local_data(global_contracted_local_data, copy=False) result.set_local_data(contracted_global_data, copy=False)
return result return result
...@@ -1814,7 +1878,8 @@ class _not_distributor(distributor): ...@@ -1814,7 +1878,8 @@ class _not_distributor(distributor):
recvbuf[:] = sendbuf recvbuf[:] = sendbuf
return recvbuf return recvbuf
def _contraction_helper(self, parent, function, axis=None, **kwargs): def contraction_helper(self, parent, function, allow_empty_contractions,
axis=None, **kwargs):
if axis == (): if axis == ():
return parent.copy() return parent.copy()
......
...@@ -1774,11 +1774,11 @@ class Test_axis(unittest.TestCase): ...@@ -1774,11 +1774,11 @@ class Test_axis(unittest.TestCase):
decimal=4) decimal=4)
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any', itertools.product(['max', 'sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax', 'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'], 'nanmax', 'argmax'],
all_datatypes[1:], all_datatypes[1:],
[(1,), (6, 6)], [(1,)],# (6, 6)],
all_distribution_strategies, all_distribution_strategies,
[None, 0, (1, ), (0, 1)]), [None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
...@@ -1805,8 +1805,8 @@ class Test_axis(unittest.TestCase): ...@@ -1805,8 +1805,8 @@ class Test_axis(unittest.TestCase):
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax', 'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'], 'nanmax', 'argmax'],
all_datatypes[1:], all_datatypes[1:],
[(4, 4, 3)], [(4, 2, 3)],
all_distribution_strategies, all_distribution_strategies,
[(0, 1), (1, 2), (0, 1, 2)]), [(0, 1), (1, 2), (0, 1, 2)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment