Commit 5599041c authored by csongor's avatar csongor
Browse files

fix tests for median, check for numpy version

parent 23657529
Pipeline #2010 skipped
......@@ -17,6 +17,8 @@ import nifty
from nifty.d2o import distributed_data_object,\
STRATEGIES
from distutils.version import LooseVersion as lv
FOUND = {}
try:
import h5py
......@@ -1742,10 +1744,9 @@ if FOUND['h5py'] == True:
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (4, 4, 3)],
all_distribution_strategies,
......@@ -1774,10 +1775,9 @@ class Test_axis(unittest.TestCase):
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(4, 4, 3), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
......@@ -1798,4 +1798,27 @@ class Test_axis(unittest.TestCase):
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4)
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_for_median(self, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if global_shape != (0,) and global_shape != (1,) and lv("1.9.0Z") < \
lv(np.__version__):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment