Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
D2O
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
22
Issues
22
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
D2O
Commits
fb7e53de
Commit
fb7e53de
authored
Jan 24, 2017
by
Theo Steininger
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added functionality for empty-shape d2o's.
parent
b8f6f2d9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
27 deletions
+58
-27
d2o/distributed_data_object.py
d2o/distributed_data_object.py
+10
-1
d2o/distributor_factory.py
d2o/distributor_factory.py
+18
-9
test/test_distributed_data_object.py
test/test_distributed_data_object.py
+30
-17
No files found.
d2o/distributed_data_object.py
View file @
fb7e53de
...
...
@@ -1226,6 +1226,9 @@ class distributed_data_object(Versionable, object):
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmin doesn't support axis "
"keyword"
)
if
self
.
shape
==
():
return
0
if
0
in
self
.
local_shape
:
local_argmin
=
np
.
nan
local_argmin_value
=
np
.
nan
...
...
@@ -1260,6 +1263,9 @@ class distributed_data_object(Versionable, object):
if
axis
is
not
None
:
raise
NotImplementedError
(
"ERROR: argmax doesn't support axis "
"keyword"
)
if
self
.
shape
==
():
return
0
if
0
in
self
.
local_shape
:
local_argmax
=
np
.
nan
local_argmax_value
=
-
np
.
inf
...
...
@@ -1291,7 +1297,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmax_nonflat
"""
if
self
.
shape
==
():
return
(
0
,)
return
np
.
unravel_index
(
self
.
argmin
(
axis
=
axis
),
self
.
shape
)
def
argmax_nonflat
(
self
,
axis
=
None
):
...
...
@@ -1300,6 +1307,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmin_nonflat
"""
if
self
.
shape
==
():
return
(
0
,)
return
np
.
unravel_index
(
self
.
argmax
(
axis
=
axis
),
self
.
shape
)
def
conjugate
(
self
):
...
...
d2o/distributor_factory.py
View file @
fb7e53de
...
...
@@ -81,7 +81,7 @@ class _distributor_factory(object):
if
expensive_checks
:
# Check that all nodes got the same distribution_strategy
strat_list
=
comm
.
allgather
(
distribution_strategy
)
if
all
(
x
==
strat_list
[
0
]
for
x
in
strat_list
)
==
False
:
if
not
all
(
x
==
strat_list
[
0
]
for
x
in
strat_list
)
:
raise
ValueError
(
about_cstring
(
"ERROR: The distribution-strategy must be the same on "
+
"all nodes!"
))
...
...
@@ -135,7 +135,7 @@ class _distributor_factory(object):
dtype
=
np
.
dtype
(
dtype
)
if
expensive_checks
:
dtype_list
=
comm
.
allgather
(
dtype
)
if
all
(
x
==
dtype_list
[
0
]
for
x
in
dtype_list
)
==
False
:
if
not
all
(
x
==
dtype_list
[
0
]
for
x
in
dtype_list
)
:
raise
ValueError
(
about_cstring
(
"ERROR: The given dtype must be the same on all nodes!"
))
return_dict
[
'dtype'
]
=
dtype
...
...
@@ -145,17 +145,19 @@ class _distributor_factory(object):
if
distribution_strategy
in
STRATEGIES
[
'global'
]:
if
dset
is
not
None
:
global_shape
=
dset
.
shape
elif
global_data
is
not
None
and
n
p
.
isscalar
(
global_data
)
==
False
:
elif
global_data
is
not
None
and
n
ot
np
.
isscalar
(
global_data
)
:
global_shape
=
global_data
.
shape
elif
global_shape
is
not
None
:
global_shape
=
tuple
(
global_shape
)
elif
global_data
is
not
None
:
global_shape
=
()
else
:
raise
ValueError
(
about_cstring
(
"ERROR: Neither
non-0-dimensional
global_data nor "
+
"ERROR: Neither global_data nor "
+
"global_shape nor hdf5 file supplied!"
))
if
global_shape
==
():
raise
ValueError
(
about_cstring
(
"ERROR: global_shape == () is not a valid shape!"
))
#
if global_shape == ():
#
raise ValueError(about_cstring(
#
"ERROR: global_shape == () is not a valid shape!"))
if
expensive_checks
:
global_shape_list
=
comm
.
allgather
(
global_shape
)
...
...
@@ -170,7 +172,7 @@ class _distributor_factory(object):
elif
distribution_strategy
in
[
'freeform'
]:
if
isinstance
(
global_data
,
distributed_data_object
):
local_shape
=
global_data
.
local_shape
elif
local_data
is
not
None
and
n
p
.
isscalar
(
local_data
)
==
False
:
elif
local_data
is
not
None
and
n
ot
np
.
isscalar
(
local_data
)
:
local_shape
=
local_data
.
shape
elif
local_shape
is
not
None
:
local_shape
=
tuple
(
local_shape
)
...
...
@@ -240,6 +242,11 @@ class _distributor_factory(object):
comm
=
comm
,
**
kwargs
)
if
parsed_kwargs
.
get
(
'global_shape'
)
==
():
distribution_strategy
=
'not'
about_infos_cprint
(
"WARNING: Distribution strategy was set to "
"'not' because of global_shape == ()"
)
hashed_kwargs
=
self
.
hash_arguments
(
distribution_strategy
,
**
parsed_kwargs
)
# check if the distributors has already been produced in the past
...
...
@@ -441,6 +448,8 @@ class distributor(object):
i
+=
1
def
bincount
(
self
,
obj
,
length
,
weights
=
None
,
axis
=
None
):
if
obj
.
shape
==
():
raise
ValueError
(
"object of too small depth for desired array"
)
data
=
obj
.
get_local_data
(
copy
=
False
)
# this implementation fits all distribution strategies where the
# axes of the global array correspond to the axes of the local data
...
...
@@ -2240,7 +2249,7 @@ class _not_distributor(distributor):
if
isinstance
(
data_object
,
distributed_data_object
):
result_data
=
data_object
.
get_full_data
()
else
:
result_data
=
np
.
array
(
data_object
)
[:]
result_data
=
np
.
array
(
data_object
)
try
:
result_data
=
result_data
.
reshape
(
self
.
global_shape
)
except
ValueError
:
...
...
test/test_distributed_data_object.py
View file @
fb7e53de
...
...
@@ -115,7 +115,11 @@ def custom_name_func(testcase_func, param_num, param):
def
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
False
):
if
distribution_strategy
in
global_distribution_strategies
:
if
global_shape
==
():
obj
=
distributed_data_object
(
global_shape
=
(),
global_data
=
42.
,
distribution_strategy
=
'not'
)
global_a
=
np
.
array
(
42
)
elif
distribution_strategy
in
global_distribution_strategies
:
a
=
np
.
arange
(
np
.
prod
(
global_shape
))
a
-=
np
.
prod
(
global_shape
)
//
2
...
...
@@ -250,6 +254,10 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(
2
,
2
),
np
.
dtype
(
'int'
)],
[
None
,
(
10
,
10
),
None
,
(
10
,
10
),
np
.
dtype
(
'float64'
)],
[
1.
,
None
,
None
,
(),
np
.
dtype
(
'float64'
)],
[
None
,
(),
None
,
(),
np
.
dtype
(
'float64'
)],
],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_special_init_cases
(
self
,
...
...
@@ -269,7 +277,7 @@ class Test_Globaltype_Initialization(unittest.TestCase):
###############################################################################
if
FOUND
[
'h5py'
]
==
True
:
if
FOUND
[
'h5py'
]:
@
parameterized
.
expand
(
itertools
.
product
(
hdf5_test_paths
,
hdf5_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
...
...
@@ -289,8 +297,6 @@ class Test_Globaltype_Initialization(unittest.TestCase):
itertools
.
product
(
[(
None
,
None
,
None
,
None
,
None
),
(
None
,
None
,
np
.
int_
,
None
,
None
),
(
None
,
(),
np
.
dtype
(
'int'
),
None
,
None
),
(
1
,
None
,
None
,
None
,
None
),
(
None
,
None
,
None
,
np
.
array
([
1
,
2
,
3
]),
(
3
,)),
(
None
,
None
,
np
.
int_
,
None
,
(
3
,))],
global_distribution_strategies
),
...
...
@@ -1507,7 +1513,7 @@ class Test_contractions(unittest.TestCase):
@
parameterized
.
expand
(
itertools
.
product
([
np
.
dtype
(
'int'
),
np
.
dtype
(
'float'
),
np
.
dtype
(
'complex'
)],
[(
0
,),
(
4
,
4
)],
[(
),
(
0
,),
(
4
,
4
)],
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_vdot
(
self
,
dtype
,
global_shape
,
distribution_strategy
):
...
...
@@ -1669,7 +1675,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
([(
4
,),
(
8
,
8
),
(
0
,
4
),
(
4
,
0
,
8
)],
itertools
.
product
([(
),
(
4
,),
(
8
,
8
),
(
0
,
4
),
(
4
,
0
,
8
)],
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_flatten
(
self
,
global_shape
,
distribution_strategy
):
...
...
@@ -1678,7 +1684,7 @@ class Test_special_methods(unittest.TestCase):
distribution_strategy
)
assert_equal
(
obj
.
flatten
().
get_full_data
(),
a
.
flatten
())
p
=
obj
.
flatten
(
inplace
=
True
)
if
np
.
prod
(
global_shape
)
!=
0
:
if
np
.
prod
(
global_shape
)
!=
0
and
global_shape
!=
()
:
p
[
0
]
=
2222
assert_equal
(
obj
[(
0
,)
*
len
(
global_shape
)],
2222
)
...
...
@@ -1729,7 +1735,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
if
FOUND
[
'h5py'
]
==
True
:
if
FOUND
[
'h5py'
]:
class
Test_load_save
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
...
...
@@ -1747,7 +1753,7 @@ if FOUND['h5py'] == True:
path
=
os
.
path
.
join
(
tempfile
.
gettempdir
(),
'temp_hdf5_file.hdf5'
)
if
size
>
1
and
FOUND
[
'h5py_parallel'
]
==
False
:
if
size
>
1
and
not
FOUND
[
'h5py_parallel'
]
:
assert_raises
(
RuntimeError
,
lambda
:
obj
.
save
(
alias
=
alias
,
path
=
path
))
else
:
...
...
@@ -1812,7 +1818,7 @@ class Test_axis(unittest.TestCase):
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
],
all_datatypes
[
1
:],
[(
1
,
),
(
2
,
3
)],
[(
),
(
1
,
),
(
2
,
3
)],
all_distribution_strategies
,
[
None
,
0
,
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
...
...
@@ -1826,15 +1832,20 @@ class Test_axis(unittest.TestCase):
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function
)
(
axis
=
axis
))
else
:
if
global_shape
!=
(
1
,
):
if
global_shape
==
(
2
,
3
):
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
el
se
:
el
if
global_shape
==
(
1
,)
:
if
axis
in
[
None
,
0
,
(
0
,)]:
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
else
:
if
axis
in
[
None
]:
assert_almost_equal
(
getattr
(
obj
,
function
)(
axis
=
axis
),
getattr
(
np
,
function
)(
a
,
axis
=
axis
),
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'all'
,
'any'
,
...
...
@@ -1891,7 +1902,7 @@ class Test_axis(unittest.TestCase):
itertools
.
product
([(
'argmin_nonflat'
,
'argmin'
),
(
'argmax_nonflat'
,
'argmax'
)],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
4
,
4
,
3
),
(
4
,
0
,
3
)],
[(
),
(
0
,),
(
1
,),
(
4
,
4
,
3
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[
None
,
(
1
,
),
(
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
...
...
@@ -1900,25 +1911,24 @@ class Test_axis(unittest.TestCase):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
distribution_strategy
,
strictly_positive
=
True
)
print
(
a
,
obj
)
if
0
in
global_shape
:
assert_raises
(
ValueError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
axis
is
not
None
:
if
axis
is
not
None
and
global_shape
!=
()
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,)
:
if
len
(
global_shape
)
>
1
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
),
np
.
unravel_index
(
getattr
(
np
,
function_pair
[
1
])
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
el
se
:
el
if
len
(
global_shape
)
==
1
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])
(
axis
=
axis
),
np
.
unravel_index
(
...
...
@@ -1926,6 +1936,9 @@ class Test_axis(unittest.TestCase):
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
else
:
assert_almost_equal
(
getattr
(
obj
,
function_pair
[
0
])
(
axis
=
axis
),
(
0
,))
class
Test_arange
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment