Commit 5599041c authored by csongor's avatar csongor
Browse files

fix tests for median, check for numpy version

parent 23657529
...@@ -17,6 +17,8 @@ import nifty ...@@ -17,6 +17,8 @@ import nifty
from nifty.d2o import distributed_data_object,\ from nifty.d2o import distributed_data_object,\
STRATEGIES STRATEGIES
from distutils.version import LooseVersion as lv
FOUND = {} FOUND = {}
try: try:
import h5py import h5py
...@@ -1742,10 +1744,9 @@ if FOUND['h5py'] == True: ...@@ -1742,10 +1744,9 @@ if FOUND['h5py'] == True:
class Test_axis(unittest.TestCase): class Test_axis(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin', 'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'argmin_nonflat', 'max', 'amax', 'nanmax', 'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
'argmax', 'argmax_nonflat'],
all_datatypes[1:], all_datatypes[1:],
[(0,), (1,), (6, 6), (4, 4, 3)], [(0,), (1,), (6, 6), (4, 4, 3)],
all_distribution_strategies, all_distribution_strategies,
...@@ -1774,10 +1775,9 @@ class Test_axis(unittest.TestCase): ...@@ -1774,10 +1775,9 @@ class Test_axis(unittest.TestCase):
decimal=4) decimal=4)
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin', 'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'argmin_nonflat', 'max', 'amax', 'nanmax', 'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
'argmax', 'argmax_nonflat'],
all_datatypes[1:], all_datatypes[1:],
[(4, 4, 3), (4, 0, 3)], [(4, 4, 3), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]), all_distribution_strategies, [(0, 1), (1, 2)]),
...@@ -1798,4 +1798,27 @@ class Test_axis(unittest.TestCase): ...@@ -1798,4 +1798,27 @@ class Test_axis(unittest.TestCase):
else: else:
assert_almost_equal(getattr(obj, function) assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(), (axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4) getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_for_median(self, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if global_shape != (0,) and global_shape != (1,) and lv("1.9.0Z") < \
lv(np.__version__):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment