Commit 724a857f authored by csongor's avatar csongor
Browse files

fix selective_reduce and axis tests for max, amax, etc

parent 74a9b76c
Pipeline #2201 skipped
......@@ -498,7 +498,7 @@ class _slicing_distributor(distributor):
op=op)
return recvbuf
def _selective_allreduce(self, data, op, bufferQ=False):
def _selective_allreduce(self, data, function, bufferQ=False):
size = self.comm.size
rank = self.comm.rank
......@@ -557,14 +557,29 @@ class _slicing_distributor(distributor):
else:
temp_data = np.empty(new_shape, dtype=new_dtype)
self.comm.Bcast([temp_data, mpi_dtype], root=i)
result_data = op(result_data, temp_data)
result_data = function([result_data, temp_data],
axis=(0,))
else:
result_data = self.comm.bcast(data, root=start)
if bufferQ and got_array_list[start] == 4:
# This if is here just because multiple processes need
# the same runtime and message exchange, or they derail
(new_dtype,
new_shape) = self.comm.bcast((result_data.dtype,
result_data.shape),
root=start)
mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
if rank == start:
result_data2 = data
else:
result_data2 = np.empty(new_shape, dtype=new_dtype)
self.comm.Bcast([result_data2, mpi_dtype], root=start)
for i in xrange(start+1, size):
if got_array_list[i] > 1:
temp_data = self.comm.bcast(data, root=i)
result_data = op(result_data, temp_data)
result_data = function([result_data, temp_data],
axis=(0,))
return result_data
def contraction_helper(self, parent, function, allow_empty_contractions,
......@@ -592,7 +607,7 @@ class _slicing_distributor(distributor):
(mpi_op, bufferQ) = op_translate_dict[function]
contracted_global_data = self._selective_allreduce(
contracted_local_data,
mpi_op,
function,
bufferQ)
new_dist_strategy = 'not'
else:
......
......@@ -1774,11 +1774,11 @@ class Test_axis(unittest.TestCase):
decimal=4)
@parameterized.expand(
itertools.product(['max', 'sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(1,)],# (6, 6)],
[(1, ), (2, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
......@@ -1796,7 +1796,7 @@ class Test_axis(unittest.TestCase):
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if axis is None or axis == 0 or axis == (0,):
if axis in [None, 0, (0,)]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
......@@ -1805,7 +1805,7 @@ class Test_axis(unittest.TestCase):
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
all_datatypes[1:],
[(4, 2, 3)],
all_distribution_strategies,
[(0, 1), (1, 2), (0, 1, 2)]),
......
......@@ -1356,7 +1356,7 @@ print generate_space('rg_space')
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [8, 16],
itertools.product(point_like_spaces, [8, 4],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin', 'amax', 'nanmax',
'argmax'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment