Commit 74a9b76c authored by theos's avatar theos
Browse files

Fixed the exception handling in _selective_allreduce.

parent b511fde7
Pipeline #2152 skipped
......@@ -503,6 +503,8 @@ class _slicing_distributor(distributor):
rank = self.comm.rank
if size == 1:
if data is None:
raise ValueError("ERROR: No process with non-None data.")
result_data = data
else:
......@@ -511,26 +513,31 @@ class _slicing_distributor(distributor):
if data is None:
got_array = np.array([0])
elif not isinstance(data, np.ndarray):
got_array = np.array([2])
elif reduce(lambda x, y: x*y, data.shape) == 0:
got_array = np.array([1])
elif np.issubdtype(data.dtype, np.complexfloating):
# MPI.MAX and MPI.MIN do not support complex data types
got_array = np.array([2])
else:
got_array = np.array([3])
else:
got_array = np.array([4])
got_array_list = np.empty(size, dtype=np.int)
self.comm.Allgather([got_array, MPI.INT],
[got_array_list, MPI.INT])
if reduce(lambda x, y: x & y, got_array_list == 1):
return data
# get first node with non-None data
try:
start = next(i for i in xrange(size) if got_array_list[i] > 0)
start = next(i for i in xrange(size) if got_array_list[i] > 1)
except(StopIteration):
raise ValueError("ERROR: No process with non-None data.")
# check if the Uppercase function can be used or not
# -> check if op supports buffers and if we got real array-data
if bufferQ and got_array[start] == 3:
if bufferQ and got_array[start] == 4:
# Send the dtype and shape from the start process to the others
(new_dtype,
new_shape) = self.comm.bcast((data.dtype,
......@@ -544,7 +551,7 @@ class _slicing_distributor(distributor):
self.comm.Bcast([result_data, mpi_dtype], root=start)
for i in xrange(start+1, size):
if got_array_list[i]:
if got_array_list[i] > 1:
if rank == i:
temp_data = data
else:
......@@ -555,7 +562,7 @@ class _slicing_distributor(distributor):
else:
result_data = self.comm.bcast(data, root=start)
for i in xrange(start+1, size):
if got_array_list[i]:
if got_array_list[i] > 1:
temp_data = self.comm.bcast(data, root=i)
result_data = op(result_data, temp_data)
return result_data
......@@ -575,21 +582,10 @@ class _slicing_distributor(distributor):
local_data = parent.data
# if all local data is empty and empty_contractions are forbidden
# call function on the local_data in order to raise the right exception
if self.global_dim == 0 and not allow_empty_contractions:
# this shall raise an exception
function(local_data, axis=axis, **kwargs)
# do the contraction on the node's local data
if self.local_dim == 0 and not allow_empty_contractions:
# this case will only be reached if some nodes have data and some
# not
contracted_local_data = None
else:
# if local_dim == 0 but empty contractions will be allowed
# this will be a `contraction neutral` array.
try:
contracted_local_data = function(local_data, axis=axis, **kwargs)
except(ValueError):
contracted_local_data = None
# check if additional contraction along the first axis must be done
if axis is None or 0 in axis:
......@@ -600,6 +596,9 @@ class _slicing_distributor(distributor):
bufferQ)
new_dist_strategy = 'not'
else:
if contracted_local_data is None:
# raise the exception implicitly
function(local_data, axis=axis, **kwargs)
contracted_global_data = contracted_local_data
new_dist_strategy = parent.distribution_strategy
......
......@@ -1752,7 +1752,7 @@ class Test_axis(unittest.TestCase):
all_distribution_strategies,
[None, (0, ), (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
def test_axis_with_operations_0_dimention(self, function, dtype,
def test_axis_with_operations_0_dimension(self, function, dtype,
global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment