Commit 0bba7a71 authored by theos
Browse files

Added a _selective_allreduce method to the slicing distributor in order to fix...

Added a _selective_allreduce method to the slicing distributor in order to fix contractions that involve nodes with empty data.
parent af793f09
Pipeline #2053 skipped
......@@ -1057,10 +1057,12 @@ class distributed_data_object(object):
--------
numpy.amin
"""
return self.distributor._contraction_helper(self,
np.amin,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.amin,
allow_empty_contractions=False,
axis=axis,
**kwargs)
def nanmin(self, axis=None, **kwargs):
""" Returns the minimum of an array ignoring all NaNs.
......@@ -1069,10 +1071,12 @@ class distributed_data_object(object):
--------
numpy.nanmin
"""
return self.distributor._contraction_helper(self,
np.nanmin,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.nanmin,
allow_empty_contractions=False,
axis=axis,
**kwargs)
def max(self, axis=None, **kwargs):
""" x.max() <==> x.amax() """
......@@ -1085,10 +1089,12 @@ class distributed_data_object(object):
--------
numpy.amax
"""
return self.distributor._contraction_helper(self,
np.amax,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.amax,
allow_empty_contractions=False,
axis=axis,
**kwargs)
def nanmax(self, axis=None, **kwargs):
""" Returns the maximum of an array ignoring all NaNs.
......@@ -1097,10 +1103,12 @@ class distributed_data_object(object):
--------
numpy.nanmax
"""
return self.distributor._contraction_helper(self,
np.nanmax,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.nanmax,
allow_empty_contractions=False,
axis=axis,
**kwargs)
def sum(self, axis=None, **kwargs):
""" Sums the array elements.
......@@ -1109,10 +1117,12 @@ class distributed_data_object(object):
--------
numpy.sum
"""
return self.distributor._contraction_helper(self,
np.sum,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.sum,
allow_empty_contractions=True,
axis=axis,
**kwargs)
def prod(self, axis=None, **kwargs):
""" Multiplies the array elements.
......@@ -1121,22 +1131,28 @@ class distributed_data_object(object):
--------
numpy.prod
"""
return self.distributor._contraction_helper(self,
np.prod,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.prod,
allow_empty_contractions=True,
axis=axis,
**kwargs)
def all(self, axis=None, **kwargs):
return self.distributor._contraction_helper(self,
np.all,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.all,
allow_empty_contractions=True,
axis=axis,
**kwargs)
def any(self, axis=None, **kwargs):
return self.distributor._contraction_helper(self,
np.any,
axis=axis,
**kwargs)
return self.distributor.contraction_helper(
self,
np.any,
allow_empty_contractions=True,
axis=axis,
**kwargs)
def mean(self, axis=None, **kwargs):
# infer, which axes will be collapsed
......
......@@ -409,6 +409,7 @@ class _slicing_distributor(distributor):
self.local_start = self._local_size[0]
self.local_end = self._local_size[1]
self.global_shape = self._local_size[2]
self.global_dim = reduce(lambda x, y: x*y, self.global_shape)
self.local_length = self.local_end - self.local_start
self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
......@@ -497,7 +498,70 @@ class _slicing_distributor(distributor):
op=op)
return recvbuf
def _contraction_helper(self, parent, function, axis=None, **kwargs):
def _selective_allreduce(self, data, op, bufferQ=False):
size = self.comm.size
rank = self.comm.rank
if size == 1:
result_data = data
else:
# infer which data should be included in the allreduce and if its
# array data
if data is None:
got_array = np.array([0])
elif not isinstance(data, np.ndarray):
got_array = np.array([1])
elif np.issubdtype(data.dtype, np.complexfloating):
# MPI.MAX and MPI.MIN do not support complex data types
got_array = np.array([2])
else:
got_array = np.array([3])
got_array_list = np.empty(size, dtype=np.int)
self.comm.Allgather([got_array, MPI.INT],
[got_array_list, MPI.INT])
# get first node with non-None data
try:
start = next(i for i in xrange(size) if got_array_list[i] > 0)
except(StopIteration):
raise ValueError("ERROR: No process with non-None data.")
# check if the Uppercase function can be used or not
# -> check if op supports buffers and if we got real array-data
if bufferQ and got_array[start] == 3:
# Send the dtype and shape from the start process to the others
(new_dtype,
new_shape) = self.comm.bcast((data.dtype,
data.shape), root=start)
mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
if rank == start:
result_data = data
else:
result_data = np.empty(new_shape, dtype=new_dtype)
self.comm.Bcast([result_data, mpi_dtype], root=start)
for i in xrange(start+1, size):
if got_array_list[i]:
if rank == i:
temp_data = data
else:
temp_data = np.empty(new_shape, dtype=new_dtype)
self.comm.Bcast([temp_data, mpi_dtype], root=i)
result_data = op(result_data, temp_data)
else:
result_data = self.comm.bcast(data, root=start)
for i in xrange(start+1, size):
if got_array_list[i]:
temp_data = self.comm.bcast(data, root=i)
result_data = op(result_data, temp_data)
return result_data
def contraction_helper(self, parent, function, allow_empty_contractions,
axis=None, **kwargs):
if axis == ():
return parent.copy()
......@@ -509,40 +573,40 @@ class _slicing_distributor(distributor):
new_shape = tuple([old_shape[i] for i in xrange(len(old_shape))
if i not in axis])
# do the contraction on the node's local data
local_data = parent.data
contracted_local_data = function(local_data, axis=axis, **kwargs)
new_dtype = contracted_local_data.dtype
# if all local data is empty and empty_contractions are forbidden
# call function on the local_data in order to raise the right exception
if self.global_dim == 0 and not allow_empty_contractions:
# this shall raise an exception
function(local_data, axis=axis, **kwargs)
# do the contraction on the node's local data
if self.local_dim == 0 and not allow_empty_contractions:
# this case will only be reached if some nodes have data and some
# not
contracted_local_data = None
else:
# if local_dim == 0 but empty contractions will be allowed
# this will be a `contraction neutral` array.
contracted_local_data = function(local_data, axis=axis, **kwargs)
# check if additional contraction along the first axis must be done
if axis is None or 0 in axis:
(mpi_op, bufferQ) = op_translate_dict[function]
# check if allreduce must be used instead of Allreduce
use_Uppercase = False
if bufferQ and isinstance(contracted_local_data, np.ndarray):
# MPI.MAX and MPI.MIN do not support complex data types
if not np.issubdtype(contracted_local_data.dtype,
np.complexfloating):
use_Uppercase = True
if use_Uppercase:
global_contracted_local_data = np.empty_like(
contracted_local_data)
new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
self.comm.Allreduce([contracted_local_data,
new_mpi_dtype],
[global_contracted_local_data,
new_mpi_dtype],
op=mpi_op)
else:
global_contracted_local_data = self.comm.allreduce(
contracted_local_data, op=mpi_op)
contracted_global_data = self._selective_allreduce(
contracted_local_data,
mpi_op,
bufferQ)
new_dist_strategy = 'not'
else:
contracted_global_data = contracted_local_data
new_dist_strategy = parent.distribution_strategy
global_contracted_local_data = contracted_local_data
new_dtype = contracted_global_data.dtype
if new_shape == ():
result = global_contracted_local_data
result = contracted_global_data
else:
# try to store the result in a distributed_data_object with the
# distribution_strategy as parent
......@@ -556,12 +620,12 @@ class _slicing_distributor(distributor):
# Contracting (4, 4) to (4,).
# (4, 4) was distributed (1, 4)...(1, 4)
# (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
if result.local_shape != global_contracted_local_data.shape:
if result.local_shape != contracted_global_data.shape:
result = parent.copy_empty(
local_shape=global_contracted_local_data.shape,
local_shape=contracted_global_data.shape,
dtype=new_dtype,
distribution_strategy='freeform')
result.set_local_data(global_contracted_local_data, copy=False)
result.set_local_data(contracted_global_data, copy=False)
return result
......@@ -1814,7 +1878,8 @@ class _not_distributor(distributor):
recvbuf[:] = sendbuf
return recvbuf
def _contraction_helper(self, parent, function, axis=None, **kwargs):
def contraction_helper(self, parent, function, allow_empty_contractions,
axis=None, **kwargs):
if axis == ():
return parent.copy()
......
......@@ -1774,11 +1774,11 @@ class Test_axis(unittest.TestCase):
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
itertools.product(['max', 'sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(1,), (6, 6)],
[(1,)],# (6, 6)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
......@@ -1805,8 +1805,8 @@ class Test_axis(unittest.TestCase):
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(4, 4, 3)],
all_datatypes[1:],
[(4, 2, 3)],
all_distribution_strategies,
[(0, 1), (1, 2), (0, 1, 2)]),
testcase_func_name=custom_name_func)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment