Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
af793f09
Commit
af793f09
authored
Apr 27, 2016
by
csongor
Browse files
fix axus test for all cases and exceptions
parent
154a163f
Changes
2
Hide whitespace changes
Inline
Side-by-side
test/test_nifty_mpi_data.py
View file @
af793f09
...
...
@@ -1748,56 +1748,85 @@ class Test_axis(unittest.TestCase):
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
6
,
6
),
(
4
,
4
,
3
)],
[(
0
,),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[
None
,
(
0
,
),
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_operations_0_dimention
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
]:
if
not
(
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'max'
,
'amax'
,
'nanmax'
]
and
axis
==
(
0
,
)
and
global_shape
==
(
4
,
0
,
3
)):
assert_raises
(
ValueError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
if
axis
in
[(
1
,
),
(
0
,
1
)]
and
global_shape
==
(
0
,):
assert_raises
(
StandardError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
],
all_datatypes
[
1
:],
[(
1
,),
(
6
,
6
)],
all_distribution_strategies
,
[
None
,
0
,
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_
func
tions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
def
test_axis_with_
opera
tions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
function
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
if
function
in
[
'argmin'
,
'argmax'
]
and
axis
is
not
None
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,):
if
global_shape
!=
(
1
,):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
else
:
if
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'max'
,
'amax'
,
'nanmax'
]:
assert_raises
(
ValueError
)
else
:
if
axis
is
None
or
axis
==
0
or
axis
==
(
0
,):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
if
axis
is
None
or
axis
==
0
or
axis
==
(
0
,):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'
argmin_nonflat
'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'
max'
,
'amax
'
,
'nanmax'
,
'argmax'
],
all_datatypes
[
1
:],
[(
4
,
4
,
3
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[(
0
,
1
),
(
1
,
2
)]),
[(
4
,
4
,
3
)],
all_distribution_strategies
,
[(
0
,
1
),
(
1
,
2
),
(
0
,
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_
func
tions_
for_
many_dimentions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
def
test_axis_with_
opera
tions_many_dimentions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
function
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
if
function
in
[
'argmin'
,
'argmax'
]
and
axis
is
not
None
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
if
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'max'
,
'amax'
,
'nanmax'
]
\
and
np
.
prod
(
global_shape
)
==
0
:
assert_raises
(
ValueError
)
and
0
in
global_shape
:
assert_raises
(
ValueError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
assert_almost_equal
(
getattr
(
obj
,
function
)
(
axis
=
axis
).
get_full_data
(),
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
...
...
@@ -1822,3 +1851,41 @@ class Test_axis(unittest.TestCase):
assert_almost_equal
(
getattr
(
obj
,
'median'
)(
axis
=
axis
),
getattr
(
np
,
'median'
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([(
'argmin_nonflat'
,
'argmin'
),
(
'argmax_nonflat'
,
'argmax'
)],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
4
,
4
,
3
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[
None
,
(
1
,
),
(
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_for_nonflats
(
self
,
function_pair
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
0
in
global_shape
:
assert_raises
(
ValueError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
axis
is
not
None
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,):
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
),
np
.
unravel_index
(
getattr
(
np
,
function_pair
[
1
])
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
else
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])
(
axis
=
axis
),
np
.
unravel_index
(
getattr
(
np
,
function_pair
[
1
])
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
test/test_nifty_spaces.py
View file @
af793f09
...
...
@@ -1358,23 +1358,31 @@ class Test_axis(unittest.TestCase):
@
parameterized
.
expand
(
itertools
.
product
(
point_like_spaces
,
[
8
,
16
],
[
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
'any'
,
'amin'
,
'nanmin'
,
'argmin'
,
'amax'
,
'nanmax'
,
'argmax'
],
[
None
,
(
0
,)],
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_unary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
s
=
generate_space_with_size
(
name
,
np
.
prod
(
num
)
,
datamodel
=
datamodel
)
s
=
generate_space_with_size
(
name
,
num
,
datamodel
=
datamodel
)
d
=
generate_data
(
s
)
a
=
d
.
get_full_data
()
if
op
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
if
op
in
[
'argmin'
,
'argmax'
]
and
axis
is
not
None
:
assert_raises
(
NotImplementedError
,
lambda
:
s
.
unary_operation
(
d
,
op
,
axis
=
axis
))
else
:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
axis
),
getattr
(
np
,
op
)(
a
,
axis
=
axis
),
decimal
=
4
)
if
name
in
[
'rg_space'
]:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)),
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
1
,)),
getattr
(
np
,
op
)(
a
,
axis
=
(
1
,)),
decimal
=
4
)
if
op
in
[
'argmin'
,
'argmax'
]:
assert_raises
(
NotImplementedError
,
lambda
:
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)))
assert_raises
(
NotImplementedError
,
lambda
:
s
.
unary_operation
(
d
,
op
,
axis
=
(
1
,
)))
else
:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)),
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
1
,)),
getattr
(
np
,
op
)(
a
,
axis
=
(
1
,)),
decimal
=
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment