Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
D2O
Commits
bdd2c500
Commit
bdd2c500
authored
Dec 03, 2016
by
Theo Steininger
Browse files
Fixed tests for argmin, argmax.
parent
47c069d1
Pipeline
#8813
passed with stage
in 4 minutes and 1 second
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/distributed_data_object.py
View file @
bdd2c500
...
...
@@ -1282,6 +1282,7 @@ class distributed_data_object(Versionable, object):
(
'index'
,
np
.
dtype
(
'float'
))])
local_argmax_list
=
np
.
sort
(
local_argmax_list
,
order
=
[
'value'
,
'index'
])
# take the last entry here and correct the minus sign of the index
return
-
np
.
int
(
local_argmax_list
[
-
1
][
1
])
...
...
test/test_distributed_data_object.py
View file @
bdd2c500
...
...
@@ -18,7 +18,8 @@
from
numpy.testing
import
assert_equal
,
\
assert_almost_equal
,
\
assert_raises
assert_raises
,
\
assert_allclose
from
nose_parameterized
import
parameterized
import
unittest
...
...
@@ -1530,8 +1531,8 @@ class Test_contractions(unittest.TestCase):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
assert_al
most_equal
(
getattr
(
obj
,
function
)(),
getattr
(
np
,
function
)(
a
),
decimal
=
4
)
assert_al
lclose
(
getattr
(
obj
,
function
)(),
getattr
(
np
,
function
)(
a
),
rtol
=
1e-
4
)
###############################################################################
...
...
@@ -1547,8 +1548,8 @@ class Test_contractions(unittest.TestCase):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
assert_al
most_equal
(
getattr
(
obj
,
function
)(),
getattr
(
np
,
function
)(
a
),
decimal
=
4
)
assert_al
lclose
(
getattr
(
obj
,
function
)(),
getattr
(
np
,
function
)(
a
),
rtol
=
1e-
4
)
###############################################################################
...
...
@@ -1557,9 +1558,13 @@ class Test_contractions(unittest.TestCase):
all_distribution_strategies
))
def
test_argmin_argmax
(
self
,
dtype
,
distribution_strategy
):
print
(
dtype
,
distribution_strategy
)
global_shape
=
(
8
,
8
)
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
)
distribution_strategy
,
strictly_positive
=
True
)
o_full
=
obj
.
get_full_data
()
print
(
a
,
o_full
)
assert_equal
(
obj
.
argmax
(),
np
.
argmax
(
a
))
assert_equal
(
obj
.
argmin
(),
np
.
argmin
(
a
))
assert_equal
(
obj
.
argmin_nonflat
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment