Commit af793f09 authored by csongor's avatar csongor
Browse files

fix axus test for all cases and exceptions

parent 154a163f
Pipeline #2048 skipped
...@@ -1748,56 +1748,85 @@ class Test_axis(unittest.TestCase): ...@@ -1748,56 +1748,85 @@ class Test_axis(unittest.TestCase):
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat', 'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'], 'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:], all_datatypes[1:],
[(0,), (1,), (6, 6), (4, 4, 3)], [(0,), (4, 0, 3)],
all_distribution_strategies,
[None, (0, ), (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
def test_axis_with_operations_0_dimention(self, function, dtype,
global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat']:
if not (function in ['min', 'amin', 'nanmin', 'max', 'amax',
'nanmax'] and axis == (0, ) and global_shape == (4, 0, 3)):
assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis))
else:
if axis in [(1, ), (0, 1)] and global_shape == (0,):
assert_raises(StandardError, lambda: getattr(obj, function)
(axis=axis))
else:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(1,), (6, 6)],
all_distribution_strategies, all_distribution_strategies,
[None, 0, (1, ), (0, 1)]), [None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_axis_with_functions(self, function, dtype, global_shape, def test_axis_with_operations(self, function, dtype, global_shape,
distribution_strategy, axis): distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']: if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError) assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
else: else:
if global_shape != (0,) and global_shape != (1,): if global_shape != (1,):
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
decimal=4) decimal=4)
else: else:
if function in ['min', 'amin', 'nanmin', 'max', if axis is None or axis == 0 or axis == (0,):
'amax', 'nanmax']: assert_almost_equal(getattr(obj, function)(axis=axis),
assert_raises(ValueError) getattr(np, function)(a, axis=axis),
else: decimal=4)
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat', 'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'], 'nanmax', 'argmax'],
all_datatypes[1:], all_datatypes[1:],
[(4, 4, 3), (4, 0, 3)], [(4, 4, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]), all_distribution_strategies,
[(0, 1), (1, 2), (0, 1, 2)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_axis_with_functions_for_many_dimentions(self, function, dtype, def test_axis_with_operations_many_dimentions(self, function, dtype,
global_shape, global_shape,
distribution_strategy, distribution_strategy,
axis): axis):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']: if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError) assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
else: else:
if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\ if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\
and np.prod(global_shape) == 0: and 0 in global_shape:
assert_raises(ValueError) assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis))
else: else:
assert_almost_equal(getattr(obj, function) assert_almost_equal(getattr(obj, function)(axis=axis),
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
decimal=4) decimal=4)
...@@ -1822,3 +1851,41 @@ class Test_axis(unittest.TestCase): ...@@ -1822,3 +1851,41 @@ class Test_axis(unittest.TestCase):
assert_almost_equal(getattr(obj, 'median')(axis=axis), assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis), getattr(np, 'median')(a, axis=axis),
decimal=4) decimal=4)
@parameterized.expand(
itertools.product([('argmin_nonflat', 'argmin'),
('argmax_nonflat', 'argmax')],
all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, (1, ), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_for_nonflats(self, function_pair, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if 0 in global_shape:
assert_raises(ValueError,
lambda: getattr(obj, function_pair[0])(axis=axis))
else:
if axis is not None:
assert_raises(NotImplementedError,
lambda: getattr(obj, function_pair[0])(axis=axis))
else:
if global_shape != (0,) and global_shape != (1,):
assert_almost_equal(
getattr(obj, function_pair[0])(axis=axis),
np.unravel_index(getattr(np, function_pair[1])
(a, axis=axis),
dims=global_shape),
decimal=4)
else:
assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis),
np.unravel_index(
getattr(np, function_pair[1])
(a, axis=axis),
dims=global_shape),
decimal=4)
...@@ -1358,23 +1358,31 @@ class Test_axis(unittest.TestCase): ...@@ -1358,23 +1358,31 @@ class Test_axis(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
itertools.product(point_like_spaces, [8, 16], itertools.product(point_like_spaces, [8, 16],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all', ['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin', 'any', 'amin', 'nanmin', 'argmin', 'amax', 'nanmax',
'argmin_nonflat', 'amax', 'nanmax', 'argmax', 'argmax'],
'argmax_nonflat'],
[None, (0,)], [None, (0,)],
DATAMODELS['point_space']), DATAMODELS['point_space']),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_unary_operations(self, name, num, op, axis, datamodel): def test_unary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel) s = generate_space_with_size(name, num, datamodel=datamodel)
d = generate_data(s) d = generate_data(s)
a = d.get_full_data() a = d.get_full_data()
if op in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']: if op in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError) assert_raises(NotImplementedError, lambda: s.unary_operation
(d, op, axis=axis))
else: else:
assert_almost_equal(s.unary_operation(d, op, axis=axis), assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4) getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']: if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)), if op in ['argmin', 'argmax']:
getattr(np, op)(a, axis=(0, 1)), decimal=4) assert_raises(NotImplementedError, lambda: s.unary_operation
assert_almost_equal(s.unary_operation(d, op, axis=(1,)), (d, op, axis=(0, 1)))
getattr(np, op)(a, axis=(1,)), decimal=4) assert_raises(NotImplementedError, lambda: s.unary_operation
(d, op, axis=(1, )))
else:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)),
decimal=4)
assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
getattr(np, op)(a, axis=(1,)),
decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment