Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
8cf39e8f
Commit
8cf39e8f
authored
Apr 22, 2016
by
Theo Steininger
Browse files
Merge branch 'tests-for-axis' into 'add_axis_keyword_to_d2o'
Tests for axis See merge request
!9
parents
d72b7996
facb6103
Pipeline
#1901
skipped
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/distributed_data_object.py
View file @
8cf39e8f
# -*- coding: utf-8 -*-
import
numbers
as
numbers
import
numpy
as
np
...
...
@@ -1175,9 +1176,13 @@ class distributed_data_object(object):
def
std
(
self
,
axis
=
None
):
""" Returns the standard deviation of the d2o's elements. """
return
np
.
sqrt
(
self
.
var
(
axis
=
axis
))
var
=
self
.
var
(
axis
=
axis
)
if
isinstance
(
var
,
numbers
.
Number
):
return
np
.
sqrt
(
var
)
else
:
return
var
.
apply_scalar_function
(
np
.
sqrt
)
def
argmin
(
self
):
def
argmin
(
self
,
axis
=
None
):
""" Returns the (flat) index of the d2o's smallest value.
See Also:
...
...
@@ -1187,6 +1192,9 @@ class distributed_data_object(object):
if
0
in
self
.
shape
:
raise
ValueError
(
"ERROR: attempt to get argmin of an empty object"
)
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmin doesn't support axis "
"keyword"
)
if
0
in
self
.
local_shape
:
local_argmin
=
np
.
nan
local_argmin_value
=
np
.
nan
...
...
@@ -1208,7 +1216,7 @@ class distributed_data_object(object):
order
=
[
'value'
,
'index'
])
return
np
.
int
(
local_argmin_list
[
0
][
1
])
def
argmax
(
self
):
def
argmax
(
self
,
axis
=
None
):
""" Returns the (flat) index of the d2o's biggest value.
See Also:
...
...
@@ -1218,6 +1226,9 @@ class distributed_data_object(object):
if
0
in
self
.
shape
:
raise
ValueError
(
"ERROR: attempt to get argmax of an empty object"
)
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmax doesn't support axis "
"keyword"
)
if
0
in
self
.
local_shape
:
local_argmax
=
np
.
nan
local_argmax_value
=
np
.
nan
...
...
@@ -1238,23 +1249,22 @@ class distributed_data_object(object):
order
=
[
'value'
,
'index'
])
return
np
.
int
(
local_argmax_list
[
0
][
1
])
def
argmin_nonflat
(
self
):
def
argmin_nonflat
(
self
,
axis
=
None
):
""" Returns the unraveld index of the d2o's smallest value.
See Also:
argmin, argmax, argmax_nonflat
"""
return
np
.
unravel_index
(
self
.
argmin
(),
self
.
shape
)
return
np
.
unravel_index
(
self
.
argmin
(
axis
=
axis
),
self
.
shape
)
def
argmax_nonflat
(
self
):
def
argmax_nonflat
(
self
,
axis
=
None
):
""" Returns the unraveld index of the d2o's biggest value.
See Also:
argmin, argmax, argmin_nonflat
"""
return
np
.
unravel_index
(
self
.
argmax
(),
self
.
shape
)
return
np
.
unravel_index
(
self
.
argmax
(
axis
=
axis
),
self
.
shape
)
def
conjugate
(
self
):
""" Returns the element-wise complex conjugate. """
...
...
@@ -1284,7 +1294,15 @@ class distributed_data_object(object):
about
.
warnings
.
cprint
(
"WARNING: The current implementation of median is very expensive!"
)
median
=
np
.
median
(
self
.
get_full_data
(),
axis
=
axis
,
**
kwargs
)
return
median
if
isinstance
(
median
,
numbers
.
Number
):
return
median
else
:
x
=
self
.
copy_empty
(
global_shape
=
median
.
shape
,
dtype
=
median
.
dtype
,
distribution_strategy
=
'not'
)
x
.
set_local_data
(
median
)
return
x
def
_is_helper
(
self
,
function
):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
...
...
nifty_core.py
View file @
8cf39e8f
...
...
@@ -892,7 +892,7 @@ class point_space(space):
def
apply_scalar_function
(
self
,
x
,
function
,
inplace
=
False
):
return
x
.
apply_scalar_function
(
function
,
inplace
=
inplace
)
def
unary_operation
(
self
,
x
,
op
=
'None'
,
**
kwargs
):
def
unary_operation
(
self
,
x
,
op
=
'None'
,
axis
=
None
,
**
kwargs
):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
...
...
@@ -903,21 +903,21 @@ class point_space(space):
'abs'
:
lambda
y
:
getattr
(
y
,
'__abs__'
)(),
'real'
:
lambda
y
:
getattr
(
y
,
'real'
),
'imag'
:
lambda
y
:
getattr
(
y
,
'imag'
),
'nanmin'
:
lambda
y
:
getattr
(
y
,
'nanmin'
)(),
'amin'
:
lambda
y
:
getattr
(
y
,
'amin'
)(),
'nanmax'
:
lambda
y
:
getattr
(
y
,
'nanmax'
)(),
'amax'
:
lambda
y
:
getattr
(
y
,
'amax'
)(),
'median'
:
lambda
y
:
getattr
(
y
,
'median'
)(),
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(),
'argmin'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(),
'argmin
_flat
'
:
lambda
y
:
getattr
(
y
,
'argmin'
)(),
'argmax'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(),
'argmax
_flat
'
:
lambda
y
:
getattr
(
y
,
'argmax'
)(),
'nanmin'
:
lambda
y
:
getattr
(
y
,
'nanmin'
)(
axis
=
axis
),
'amin'
:
lambda
y
:
getattr
(
y
,
'amin'
)(
axis
=
axis
),
'nanmax'
:
lambda
y
:
getattr
(
y
,
'nanmax'
)(
axis
=
axis
),
'amax'
:
lambda
y
:
getattr
(
y
,
'amax'
)(
axis
=
axis
),
'median'
:
lambda
y
:
getattr
(
y
,
'median'
)(
axis
=
axis
),
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(
axis
=
axis
),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(
axis
=
axis
),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(
axis
=
axis
),
'argmin
_nonflat
'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(
axis
=
axis
),
'argmin'
:
lambda
y
:
getattr
(
y
,
'argmin'
)(
axis
=
axis
),
'argmax
_nonflat
'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(
axis
=
axis
),
'argmax'
:
lambda
y
:
getattr
(
y
,
'argmax'
)(
axis
=
axis
),
'conjugate'
:
lambda
y
:
getattr
(
y
,
'conjugate'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(),
'prod'
:
lambda
y
:
getattr
(
y
,
'prod'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(
axis
=
axis
),
'prod'
:
lambda
y
:
getattr
(
y
,
'prod'
)(
axis
=
axis
),
'unique'
:
lambda
y
:
getattr
(
y
,
'unique'
)(),
'copy'
:
lambda
y
:
getattr
(
y
,
'copy'
)(),
'copy_empty'
:
lambda
y
:
getattr
(
y
,
'copy_empty'
)(),
...
...
@@ -925,8 +925,8 @@ class point_space(space):
'isinf'
:
lambda
y
:
getattr
(
y
,
'isinf'
)(),
'isfinite'
:
lambda
y
:
getattr
(
y
,
'isfinite'
)(),
'nan_to_num'
:
lambda
y
:
getattr
(
y
,
'nan_to_num'
)(),
'all'
:
lambda
y
:
getattr
(
y
,
'all'
)(),
'any'
:
lambda
y
:
getattr
(
y
,
'any'
)(),
'all'
:
lambda
y
:
getattr
(
y
,
'all'
)(
axis
=
axis
),
'any'
:
lambda
y
:
getattr
(
y
,
'any'
)(
axis
=
axis
),
'None'
:
lambda
y
:
y
}
return
translation
[
op
](
x
,
**
kwargs
)
...
...
@@ -2747,10 +2747,10 @@ class field(object):
"""
if
split
:
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmin'
,
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmin
_nonflat
'
,
**
kwargs
)
else
:
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmin
_flat
'
,
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmin'
,
**
kwargs
)
def
argmax
(
self
,
split
=
True
,
**
kwargs
):
...
...
@@ -2776,10 +2776,10 @@ class field(object):
"""
if
split
:
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmax'
,
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmax
_nonflat
'
,
**
kwargs
)
else
:
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmax
_flat
'
,
return
self
.
_unary_helper
(
self
.
get_val
(),
op
=
'argmax'
,
**
kwargs
)
# TODO: Implement the full range of unary and binary operotions
...
...
test/test_nifty_mpi_data.py
View file @
8cf39e8f
...
...
@@ -1738,3 +1738,64 @@ if FOUND['h5py'] == True:
# Todo: Assert that data is copied, when copy flag is set
# Todo: Assert that set, get and injection work, if there is different data
# on the nodes
class
Test_axis
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
6
,
6
),
(
5
,
5
,
5
)],
all_distribution_strategies
,
[
None
,
0
,
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_functions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
function
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
else
:
if
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'max'
,
'amax'
,
'nanmax'
]:
assert_raises
(
ValueError
)
else
:
if
axis
is
None
or
axis
==
0
or
axis
==
(
0
,):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
[(
5
,
5
,
5
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[(
0
,
1
),
(
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_functions_for_many_dimentions
(
self
,
function
,
dtype
,
global_shape
,
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
if
function
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
else
:
if
function
in
[
'min'
,
'amin'
,
'nanmin'
,
'max'
,
'amax'
,
'nanmax'
]
\
and
np
.
prod
(
global_shape
)
==
0
:
assert_raises
(
ValueError
)
else
:
assert_almost_equal
(
getattr
(
obj
,
function
)
(
axis
=
axis
).
get_full_data
(),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
test/test_nifty_spaces.py
View file @
8cf39e8f
...
...
@@ -27,6 +27,8 @@ from nifty.nifty_power_indices import power_indices
from
nifty.nifty_utilities
import
_hermitianize_inverter
as
\
hermitianize_inverter
from
nifty.operators.nifty_operators
import
power_operator
available
=
[]
try
:
from
nifty
import
lm_space
...
...
@@ -130,9 +132,9 @@ if HP_DISTRIBUTION_STRATEGIES != []:
unary_operations
=
[
'pos'
,
'neg'
,
'abs'
,
'real'
,
'imag'
,
'nanmin'
,
'amin'
,
'nanmax'
,
'amax'
,
'median'
,
'mean'
,
'std'
,
'var'
,
'argmin'
,
'argmin_flat'
,
'argmax'
,
'argmax_flat'
,
'conjugate'
,
'sum'
,
'prod'
,
'unique'
,
'copy'
,
'copy_empty'
,
'isnan'
,
'isinf'
,
'isfinite'
,
'nan_to_num'
,
'all'
,
'any'
,
'None'
]
'argmin_
non
flat'
,
'argmax'
,
'argmax_
non
flat'
,
'conjugate'
,
'sum'
,
'prod'
,
'unique'
,
'copy'
,
'copy_empty'
,
'isnan'
,
'isinf'
,
'isfinite'
,
'nan_to_num'
,
'all'
,
'any'
,
'None'
]
binary_operations
=
[
'add'
,
'radd'
,
'iadd'
,
'sub'
,
'rsub'
,
'isub'
,
'mul'
,
'rmul'
,
'imul'
,
'div'
,
'rdiv'
,
'idiv'
,
'pow'
,
'rpow'
,
...
...
@@ -178,6 +180,22 @@ def generate_space(name):
return
space_dict
[
name
]
def
generate_space_with_size
(
name
,
num
,
datamodel
=
'fftw'
):
space_dict
=
{
'space'
:
space
(),
'point_space'
:
point_space
(
num
,
datamodel
=
datamodel
),
'rg_space'
:
rg_space
((
num
,
num
),
datamodel
=
datamodel
),
}
if
'lm_space'
in
available
:
space_dict
[
'lm_space'
]
=
lm_space
(
mmax
=
num
,
lmax
=
num
,
datamodel
=
datamodel
)
if
'hp_space'
in
available
:
space_dict
[
'hp_space'
]
=
hp_space
(
num
,
datamodel
=
datamodel
)
if
'gl_space'
in
available
:
space_dict
[
'gl_space'
]
=
gl_space
(
nlat
=
num
,
nlon
=
num
,
datamodel
=
datamodel
)
return
space_dict
[
name
]
def
generate_data
(
space
):
a
=
np
.
arange
(
space
.
get_dim
()).
reshape
(
space
.
get_shape
())
data
=
space
.
cast
(
a
)
...
...
@@ -1334,4 +1352,27 @@ class Test_Lm_Space(unittest.TestCase):
print
all_spaces
print
generate_space
(
'rg_space'
)
\ No newline at end of file
print
generate_space
(
'rg_space'
)
class
Test_axis
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
(
point_like_spaces
,
[
8
,
16
],
[
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
[
None
,
(
0
,)],
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_binary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
s
=
generate_space_with_size
(
name
,
np
.
prod
(
num
),
datamodel
=
datamodel
)
d
=
generate_data
(
s
)
a
=
d
.
get_full_data
()
if
op
in
[
'argmin'
,
'argmin_nonflat'
,
'argmax'
,
'argmax_nonflat'
]:
assert_raises
(
NotImplementedError
)
else
:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
axis
),
getattr
(
np
,
op
)(
a
,
axis
=
axis
),
decimal
=
4
)
if
name
in
[
'rg_space'
]:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)),
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment