Commit af793f09 authored by csongor's avatar csongor
Browse files

fix axus test for all cases and exceptions

parent 154a163f
Pipeline #2048 skipped
......@@ -1748,56 +1748,85 @@ class Test_axis(unittest.TestCase):
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (4, 4, 3)],
[(0,), (4, 0, 3)],
all_distribution_strategies,
[None, (0, ), (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
def test_axis_with_operations_0_dimention(self, function, dtype,
global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat']:
if not (function in ['min', 'amin', 'nanmin', 'max', 'amax',
'nanmax'] and axis == (0, ) and global_shape == (4, 0, 3)):
assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis))
else:
if axis in [(1, ), (0, 1)] and global_shape == (0,):
assert_raises(StandardError, lambda: getattr(obj, function)
(axis=axis))
else:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(1,), (6, 6)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions(self, function, dtype, global_shape,
distribution_strategy, axis):
def test_axis_with_operations(self, function, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
else:
if global_shape != (0,) and global_shape != (1,):
if global_shape != (1,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if function in ['min', 'amin', 'nanmin', 'max',
'amax', 'nanmax']:
assert_raises(ValueError)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(4, 4, 3), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
[(4, 4, 3)],
all_distribution_strategies,
[(0, 1), (1, 2), (0, 1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions_for_many_dimentions(self, function, dtype,
global_shape,
distribution_strategy,
axis):
def test_axis_with_operations_many_dimentions(self, function, dtype,
global_shape,
distribution_strategy,
axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
else:
if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\
and np.prod(global_shape) == 0:
assert_raises(ValueError)
and 0 in global_shape:
assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis))
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
......@@ -1822,3 +1851,41 @@ class Test_axis(unittest.TestCase):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product([('argmin_nonflat', 'argmin'),
('argmax_nonflat', 'argmax')],
all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, (1, ), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_for_nonflats(self, function_pair, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if 0 in global_shape:
assert_raises(ValueError,
lambda: getattr(obj, function_pair[0])(axis=axis))
else:
if axis is not None:
assert_raises(NotImplementedError,
lambda: getattr(obj, function_pair[0])(axis=axis))
else:
if global_shape != (0,) and global_shape != (1,):
assert_almost_equal(
getattr(obj, function_pair[0])(axis=axis),
np.unravel_index(getattr(np, function_pair[1])
(a, axis=axis),
dims=global_shape),
decimal=4)
else:
assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis),
np.unravel_index(
getattr(np, function_pair[1])
(a, axis=axis),
dims=global_shape),
decimal=4)
......@@ -1358,23 +1358,31 @@ class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [8, 16],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'amax', 'nanmax', 'argmax',
'argmax_nonflat'],
'any', 'amin', 'nanmin', 'argmin', 'amax', 'nanmax',
'argmax'],
[None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_unary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
s = generate_space_with_size(name, num, datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
if op in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
if op in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError, lambda: s.unary_operation
(d, op, axis=axis))
else:
assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
getattr(np, op)(a, axis=(1,)), decimal=4)
if op in ['argmin', 'argmax']:
assert_raises(NotImplementedError, lambda: s.unary_operation
(d, op, axis=(0, 1)))
assert_raises(NotImplementedError, lambda: s.unary_operation
(d, op, axis=(1, )))
else:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)),
decimal=4)
assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
getattr(np, op)(a, axis=(1,)),
decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment