Commit bdd2c500 authored by Theo Steininger's avatar Theo Steininger
Browse files

Fixed tests for argmin, argmax.

parent 47c069d1
Pipeline #8813 passed with stage
in 4 minutes and 1 second
...@@ -1282,6 +1282,7 @@ class distributed_data_object(Versionable, object): ...@@ -1282,6 +1282,7 @@ class distributed_data_object(Versionable, object):
('index', np.dtype('float'))]) ('index', np.dtype('float'))])
local_argmax_list = np.sort(local_argmax_list, local_argmax_list = np.sort(local_argmax_list,
order=['value', 'index']) order=['value', 'index'])
# take the last entry here and correct the minus sign of the index # take the last entry here and correct the minus sign of the index
return -np.int(local_argmax_list[-1][1]) return -np.int(local_argmax_list[-1][1])
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
from numpy.testing import assert_equal,\ from numpy.testing import assert_equal,\
assert_almost_equal,\ assert_almost_equal,\
assert_raises assert_raises,\
assert_allclose
from nose_parameterized import parameterized from nose_parameterized import parameterized
import unittest import unittest
...@@ -1530,8 +1531,8 @@ class Test_contractions(unittest.TestCase): ...@@ -1530,8 +1531,8 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
assert_almost_equal(getattr(obj, function)(), getattr(np, function)(a), assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
decimal=4) rtol=1e-4)
############################################################################### ###############################################################################
...@@ -1547,8 +1548,8 @@ class Test_contractions(unittest.TestCase): ...@@ -1547,8 +1548,8 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
assert_almost_equal(getattr(obj, function)(), getattr(np, function)(a), assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
decimal=4) rtol=1e-4)
############################################################################### ###############################################################################
...@@ -1557,9 +1558,13 @@ class Test_contractions(unittest.TestCase): ...@@ -1557,9 +1558,13 @@ class Test_contractions(unittest.TestCase):
all_distribution_strategies all_distribution_strategies
)) ))
def test_argmin_argmax(self, dtype, distribution_strategy): def test_argmin_argmax(self, dtype, distribution_strategy):
print (dtype, distribution_strategy)
global_shape = (8, 8) global_shape = (8, 8)
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy) distribution_strategy,
strictly_positive=True)
o_full = obj.get_full_data()
print (a, o_full)
assert_equal(obj.argmax(), np.argmax(a)) assert_equal(obj.argmax(), np.argmax(a))
assert_equal(obj.argmin(), np.argmin(a)) assert_equal(obj.argmin(), np.argmin(a))
assert_equal(obj.argmin_nonflat(), assert_equal(obj.argmin_nonflat(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment