Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
D2O
Commits
fb7e53de
Commit
fb7e53de
authored
Jan 24, 2017
by
Theo Steininger
Browse files
Added functionality for empty-shape d2o's.
parent
b8f6f2d9
Changes
3
Hide whitespace changes
Inline
Side-by-side
d2o/distributed_data_object.py
View file @
fb7e53de
...
...
@@ -1226,6 +1226,9 @@ class distributed_data_object(Versionable, object):
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmin doesn't support axis "
"keyword"
)
if
self
.
shape
==
():
return
0
if
0
in
self
.
local_shape
:
local_argmin
=
np
.
nan
local_argmin_value
=
np
.
nan
...
...
@@ -1260,6 +1263,9 @@ class distributed_data_object(Versionable, object):
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmax doesn't support axis "
"keyword"
)
if
self
.
shape
==
():
return
0
if
0
in
self
.
local_shape
:
local_argmax
=
np
.
nan
local_argmax_value
=
-
np
.
inf
...
...
@@ -1291,7 +1297,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmax_nonflat
"""
if
self
.
shape
==
():
return
(
0
,)
return
np
.
unravel_index
(
self
.
argmin
(
axis
=
axis
),
self
.
shape
)
def
argmax_nonflat
(
self
,
axis
=
None
):
...
...
@@ -1300,6 +1307,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmin_nonflat
"""
if
self
.
shape
==
():
return
(
0
,)
return
np
.
unravel_index
(
self
.
argmax
(
axis
=
axis
),
self
.
shape
)
def
conjugate
(
self
):
...
...
d2o/distributor_factory.py
View file @
fb7e53de
...
...
@@ -81,7 +81,7 @@ class _distributor_factory(object):
if
expensive_checks
:
# Check that all nodes got the same distribution_strategy
strat_list
=
comm
.
allgather
(
distribution_strategy
)
if
all
(
x
==
strat_list
[
0
]
for
x
in
strat_list
)
==
False
:
if
not
all
(
x
==
strat_list
[
0
]
for
x
in
strat_list
):
raise
ValueError
(
about_cstring
(
"ERROR: The distribution-strategy must be the same on "
+
"all nodes!"
))
...
...
@@ -135,7 +135,7 @@ class _distributor_factory(object):
dtype
=
np
.
dtype
(
dtype
)
if
expensive_checks
:
dtype_list
=
comm
.
allgather
(
dtype
)
if
all
(
x
==
dtype_list
[
0
]
for
x
in
dtype_list
)
==
False
:
if
not
all
(
x
==
dtype_list
[
0
]
for
x
in
dtype_list
):
raise
ValueError
(
about_cstring
(
"ERROR: The given dtype must be the same on all nodes!"
))
return_dict
[
'dtype'
]
=
dtype
...
...
@@ -145,17 +145,19 @@ class _distributor_factory(object):
if
distribution_strategy
in
STRATEGIES
[
'global'
]:
if
dset
is
not
None
:
global_shape
=
dset
.
shape
elif
global_data
is
not
None
and
np
.
isscalar
(
global_data
)
==
False
:
elif
global_data
is
not
None
and
not
np
.
isscalar
(
global_data
):
global_shape
=
global_data
.
shape
elif
global_shape
is
not
None
:
global_shape
=
tuple
(
global_shape
)
elif
global_data
is
not
None
:
global_shape
=
()
else
:
raise
ValueError
(
about_cstring
(
"ERROR: Neither
non-0-dimensional
global_data nor "
+
"ERROR: Neither global_data nor "
+
"global_shape nor hdf5 file supplied!"
))
if
global_shape
==
():
raise
ValueError
(
about_cstring
(
"ERROR: global_shape == () is not a valid shape!"
))
#
if global_shape == ():
#
raise ValueError(about_cstring(
#
"ERROR: global_shape == () is not a valid shape!"))
if
expensive_checks
:
global_shape_list
=
comm
.
allgather
(
global_shape
)
...
...
@@ -170,7 +172,7 @@ class _distributor_factory(object):
elif
distribution_strategy
in
[
'freeform'
]:
if
isinstance
(
global_data
,
distributed_data_object
):
local_shape
=
global_data
.
local_shape
elif
local_data
is
not
None
and
np
.
isscalar
(
local_data
)
==
False
:
elif
local_data
is
not
None
and
not
np
.
isscalar
(
local_data
):
local_shape
=
local_data
.
shape
elif
local_shape
is
not
None
:
local_shape
=
tuple
(
local_shape
)
...
...
@@ -240,6 +242,11 @@ class _distributor_factory(object):
comm
=
comm
,
**
kwargs
)
if
parsed_kwargs
.
get
(
'global_shape'
)
==
():
distribution_strategy
=
'not'
about_infos_cprint
(
"WARNING: Distribution strategy was set to "
"'not' because of global_shape == ()"
)
hashed_kwargs
=
self
.
hash_arguments
(
distribution_strategy
,
**
parsed_kwargs
)
# check if the distributors has already been produced in the past
...
...
@@ -441,6 +448,8 @@ class distributor(object):
i
+=
1
def
bincount
(
self
,
obj
,
length
,
weights
=
None
,
axis
=
None
):
if
obj
.
shape
==
():
raise
ValueError
(
"object of too small depth for desired array"
)
data
=
obj
.
get_local_data
(
copy
=
False
)
# this implementation fits all distribution strategies where the
# axes of the global array correspond to the axes of the local data
...
...
@@ -2240,7 +2249,7 @@ class _not_distributor(distributor):
if
isinstance
(
data_object
,
distributed_data_object
):
result_data
=
data_object
.
get_full_data
()
else
:
result_data
=
np
.
array
(
data_object
)
[:]
result_data
=
np
.
array
(
data_object
)
try
:
result_data
=
result_data
.
reshape
(
self
.
global_shape
)
except
ValueError
:
...
...
test/test_distributed_data_object.py
View file @
fb7e53de
...
...
@@ -115,7 +115,11 @@ def custom_name_func(testcase_func, param_num, param):
def
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
False
):
if
distribution_strategy
in
global_distribution_strategies
:
if
global_shape
==
():
obj
=
distributed_data_object
(
global_shape
=
(),
global_data
=
42.
,
distribution_strategy
=
'not'
)
global_a
=
np
.
array
(
42
)
elif
distribution_strategy
in
global_distribution_strategies
:
a
=
np
.
arange
(
np
.
prod
(
global_shape
))
a
-=
np
.
prod
(
global_shape
)
//
2
...
...
@@ -250,6 +254,10 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(
2
,
2
),
np
.
dtype
(
'int'
)],
[
None
,
(
10
,
10
),
None
,
(
10
,
10
),
np
.
dtype
(
'float64'
)],
[
1.
,
None
,
None
,
(),
np
.
dtype
(
'float64'
)],
[
None
,
(),
None
,
(),
np
.
dtype
(
'float64'
)],
],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_special_init_cases
(
self
,
...
...
@@ -269,7 +277,7 @@ class Test_Globaltype_Initialization(unittest.TestCase):
###############################################################################
if
FOUND
[
'h5py'
]
==
True
:
if
FOUND
[
'h5py'
]:
@
parameterized
.
expand
(
itertools
.
product
(
hdf5_test_paths
,
hdf5_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
...
...
@@ -289,8 +297,6 @@ class Test_Globaltype_Initialization(unittest.TestCase):
itertools
.
product
(
[(
None
,
None
,
None
,
None
,
None
),
(
None
,
None
,
np
.
int_
,
None
,
None
),
(
None
,
(),
np
.
dtype
(
'int'
),
None
,
None
),
(
1
,
None
,
None
,
None
,
None
),
(
None
,
None
,
None
,
np
.
array
([
1
,
2
,
3
]),
(
3
,)),
(
None
,
None
,
np
.
int_
,
None
,
(
3
,))],
global_distribution_strategies
),
...
...
@@ -1507,7 +1513,7 @@ class Test_contractions(unittest.TestCase):
@
parameterized
.
expand
(
itertools
.
product
([
np
.
dtype
(
'int'
),
np
.
dtype
(
'float'
),
np
.
dtype
(
'complex'
)],
[(
0
,),
(
4
,
4
)],
[
(),
(
0
,),
(
4
,
4
)],
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_vdot
(
self
,
dtype
,
global_shape
,
distribution_strategy
):
...
...
@@ -1669,7 +1675,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
([(
4
,),
(
8
,
8
),
(
0
,
4
),
(
4
,
0
,
8
)],
itertools
.
product
([
(),
(
4
,),
(
8
,
8
),
(
0
,
4
),
(
4
,
0
,
8
)],
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_flatten
(
self
,
global_shape
,
distribution_strategy
):
...
...
@@ -1678,7 +1684,7 @@ class Test_special_methods(unittest.TestCase):
distribution_strategy
)
assert_equal
(
obj
.
flatten
().
get_full_data
(),
a
.
flatten
())
p
=
obj
.
flatten
(
inplace
=
True
)
if
np
.
prod
(
global_shape
)
!=
0
:
if
np
.
prod
(
global_shape
)
!=
0
and
global_shape
!=
()
:
p
[
0
]
=
2222
assert_equal
(
obj
[(
0
,)
*
len
(
global_shape
)],
2222
)
...
...
@@ -1729,7 +1735,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
if
FOUND
[
'h5py'
]
==
True
:
if
FOUND
[
'h5py'
]:
class
Test_load_save
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
...
...
@@ -1747,7 +1753,7 @@ if FOUND['h5py'] == True:
path
=
os
.
path
.
join
(
tempfile
.
gettempdir
(),
'temp_hdf5_file.hdf5'
)
if
size
>
1
and
FOUND
[
'h5py_parallel'
]
==
False
:
if
size
>
1
and
not
FOUND
[
'h5py_parallel'
]:
assert_raises
(
RuntimeError
,
lambda
:
obj
.
save
(
alias
=
alias
,
path
=
path
))
else
:
...
...
@@ -1812,7 +1818,7 @@ class Test_axis(unittest.TestCase):
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
],
all_datatypes
[
1
:],
[(
1
,
),
(
2
,
3
)],
[
(),
(
1
,),
(
2
,
3
)],
all_distribution_strategies
,
[
None
,
0
,
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
...
...
@@ -1826,15 +1832,20 @@ class Test_axis(unittest.TestCase):
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
if
global_shape
!
=
(
1
,
):
if
global_shape
=
=
(
2
,
3
):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
el
se
:
el
if
global_shape
==
(
1
,)
:
if
axis
in
[
None
,
0
,
(
0
,)]:
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
else
:
if
axis
in
[
None
]:
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'all'
,
'any'
,
...
...
@@ -1891,7 +1902,7 @@ class Test_axis(unittest.TestCase):
itertools
.
product
([(
'argmin_nonflat'
,
'argmin'
),
(
'argmax_nonflat'
,
'argmax'
)],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
4
,
4
,
3
),
(
4
,
0
,
3
)],
[
(),
(
0
,),
(
1
,),
(
4
,
4
,
3
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[
None
,
(
1
,
),
(
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
...
...
@@ -1900,25 +1911,24 @@ class Test_axis(unittest.TestCase):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
print
(
a
,
obj
)
if
0
in
global_shape
:
assert_raises
(
ValueError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
axis
is
not
None
:
if
axis
is
not
None
and
global_shape
!=
()
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,)
:
if
len
(
global_shape
)
>
1
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
),
np
.
unravel_index
(
getattr
(
np
,
function_pair
[
1
])
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
el
se
:
el
if
len
(
global_shape
)
==
1
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])
(
axis
=
axis
),
np
.
unravel_index
(
...
...
@@ -1926,6 +1936,9 @@ class Test_axis(unittest.TestCase):
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
else
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])
(
axis
=
axis
),
(
0
,))
class
Test_arange
(
unittest
.
TestCase
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment