Commit 74a9b76c authored by theos's avatar theos
Browse files

Fixed the exception handling in _selective_allreduce.

parent b511fde7
Pipeline #2152 skipped
...@@ -503,6 +503,8 @@ class _slicing_distributor(distributor): ...@@ -503,6 +503,8 @@ class _slicing_distributor(distributor):
rank = self.comm.rank rank = self.comm.rank
if size == 1: if size == 1:
if data is None:
raise ValueError("ERROR: No process with non-None data.")
result_data = data result_data = data
else: else:
...@@ -511,26 +513,31 @@ class _slicing_distributor(distributor): ...@@ -511,26 +513,31 @@ class _slicing_distributor(distributor):
if data is None: if data is None:
got_array = np.array([0]) got_array = np.array([0])
elif not isinstance(data, np.ndarray): elif not isinstance(data, np.ndarray):
got_array = np.array([2])
elif reduce(lambda x, y: x*y, data.shape) == 0:
got_array = np.array([1]) got_array = np.array([1])
elif np.issubdtype(data.dtype, np.complexfloating): elif np.issubdtype(data.dtype, np.complexfloating):
# MPI.MAX and MPI.MIN do not support complex data types # MPI.MAX and MPI.MIN do not support complex data types
got_array = np.array([2])
else:
got_array = np.array([3]) got_array = np.array([3])
else:
got_array = np.array([4])
got_array_list = np.empty(size, dtype=np.int) got_array_list = np.empty(size, dtype=np.int)
self.comm.Allgather([got_array, MPI.INT], self.comm.Allgather([got_array, MPI.INT],
[got_array_list, MPI.INT]) [got_array_list, MPI.INT])
if reduce(lambda x, y: x & y, got_array_list == 1):
return data
# get first node with non-None data # get first node with non-None data
try: try:
start = next(i for i in xrange(size) if got_array_list[i] > 0) start = next(i for i in xrange(size) if got_array_list[i] > 1)
except(StopIteration): except(StopIteration):
raise ValueError("ERROR: No process with non-None data.") raise ValueError("ERROR: No process with non-None data.")
# check if the Uppercase function can be used or not # check if the Uppercase function can be used or not
# -> check if op supports buffers and if we got real array-data # -> check if op supports buffers and if we got real array-data
if bufferQ and got_array[start] == 3: if bufferQ and got_array[start] == 4:
# Send the dtype and shape from the start process to the others # Send the dtype and shape from the start process to the others
(new_dtype, (new_dtype,
new_shape) = self.comm.bcast((data.dtype, new_shape) = self.comm.bcast((data.dtype,
...@@ -544,7 +551,7 @@ class _slicing_distributor(distributor): ...@@ -544,7 +551,7 @@ class _slicing_distributor(distributor):
self.comm.Bcast([result_data, mpi_dtype], root=start) self.comm.Bcast([result_data, mpi_dtype], root=start)
for i in xrange(start+1, size): for i in xrange(start+1, size):
if got_array_list[i]: if got_array_list[i] > 1:
if rank == i: if rank == i:
temp_data = data temp_data = data
else: else:
...@@ -555,7 +562,7 @@ class _slicing_distributor(distributor): ...@@ -555,7 +562,7 @@ class _slicing_distributor(distributor):
else: else:
result_data = self.comm.bcast(data, root=start) result_data = self.comm.bcast(data, root=start)
for i in xrange(start+1, size): for i in xrange(start+1, size):
if got_array_list[i]: if got_array_list[i] > 1:
temp_data = self.comm.bcast(data, root=i) temp_data = self.comm.bcast(data, root=i)
result_data = op(result_data, temp_data) result_data = op(result_data, temp_data)
return result_data return result_data
...@@ -575,21 +582,10 @@ class _slicing_distributor(distributor): ...@@ -575,21 +582,10 @@ class _slicing_distributor(distributor):
local_data = parent.data local_data = parent.data
# if all local data is empty and empty_contractions are forbidden try:
# call function on the local_data in order to raise the right exception
if self.global_dim == 0 and not allow_empty_contractions:
# this shall raise an exception
function(local_data, axis=axis, **kwargs)
# do the contraction on the node's local data
if self.local_dim == 0 and not allow_empty_contractions:
# this case will only be reached if some nodes have data and some
# not
contracted_local_data = None
else:
# if local_dim == 0 but empty contractions will be allowed
# this will be a `contraction neutral` array.
contracted_local_data = function(local_data, axis=axis, **kwargs) contracted_local_data = function(local_data, axis=axis, **kwargs)
except(ValueError):
contracted_local_data = None
# check if additional contraction along the first axis must be done # check if additional contraction along the first axis must be done
if axis is None or 0 in axis: if axis is None or 0 in axis:
...@@ -600,6 +596,9 @@ class _slicing_distributor(distributor): ...@@ -600,6 +596,9 @@ class _slicing_distributor(distributor):
bufferQ) bufferQ)
new_dist_strategy = 'not' new_dist_strategy = 'not'
else: else:
if contracted_local_data is None:
# raise the exception implicitly
function(local_data, axis=axis, **kwargs)
contracted_global_data = contracted_local_data contracted_global_data = contracted_local_data
new_dist_strategy = parent.distribution_strategy new_dist_strategy = parent.distribution_strategy
......
...@@ -1752,7 +1752,7 @@ class Test_axis(unittest.TestCase): ...@@ -1752,7 +1752,7 @@ class Test_axis(unittest.TestCase):
all_distribution_strategies, all_distribution_strategies,
[None, (0, ), (1, ), (0, 1)]), [None, (0, ), (1, ), (0, 1)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_axis_with_operations_0_dimention(self, function, dtype, def test_axis_with_operations_0_dimension(self, function, dtype,
global_shape, global_shape,
distribution_strategy, axis): distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment