Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
D2O
Commits
2126ef50
Commit
2126ef50
authored
May 03, 2017
by
Theo Steininger
Browse files
Added Python3 compatibility
parent
083f6433
Pipeline
#11937
failed with stage
in 5 minutes and 6 seconds
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/__init__.py
View file @
2126ef50
...
...
@@ -17,13 +17,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
__future__
import
division
from
__future__
import
absolute_import
from
version
import
__version__
from
.
version
import
__version__
from
config
import
configuration
from
distributed_data_object
import
distributed_data_object
from
d2o_librarian
import
d2o_librarian
from
.
config
import
configuration
from
.
distributed_data_object
import
distributed_data_object
from
.
d2o_librarian
import
d2o_librarian
from
strategies
import
STRATEGIES
from
.
strategies
import
STRATEGIES
from
factory_methods
import
*
from
.
factory_methods
import
*
d2o/config/__init__.py
View file @
2126ef50
from
__future__
import
absolute_import
# D2O
# Copyright (C) 2016 Theo Steininger
#
...
...
@@ -16,5 +17,5 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
d2o_config
import
dependency_injector
,
\
from
.
d2o_config
import
dependency_injector
,
\
configuration
d2o/d2o_iter.py
View file @
2126ef50
...
...
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
object
import
numpy
as
np
...
...
@@ -29,7 +30,7 @@ class d2o_iter(object):
def
__iter__
(
self
):
return
self
def
next
(
self
):
def
__
next
__
(
self
):
if
self
.
n
==
0
:
raise
StopIteration
()
...
...
d2o/d2o_librarian.py
View file @
2126ef50
...
...
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
object
from
weakref
import
WeakValueDictionary
as
weakdict
...
...
d2o/distributed_data_object.py
View file @
2126ef50
from
__future__
import
division
from
__future__
import
absolute_import
# D2O
# Copyright (C) 2016 Theo Steininger
#
...
...
@@ -23,10 +25,10 @@ from keepers import Versionable,\
from
d2o.config
import
configuration
as
gc
,
\
dependency_injector
as
gdi
from
d2o_librarian
import
d2o_librarian
from
cast_axis_to_tuple
import
cast_axis_to_tuple
from
.
d2o_librarian
import
d2o_librarian
from
.
cast_axis_to_tuple
import
cast_axis_to_tuple
from
strategies
import
STRATEGIES
from
.
strategies
import
STRATEGIES
MPI
=
gdi
[
gc
[
'mpi_module'
]]
...
...
@@ -174,7 +176,7 @@ class distributed_data_object(Loggable, Versionable, object):
if
distribution_strategy
is
None
:
distribution_strategy
=
gc
[
'default_distribution_strategy'
]
from
distributor_factory
import
distributor_factory
from
.
distributor_factory
import
distributor_factory
self
.
distributor
=
distributor_factory
.
get_distributor
(
distribution_strategy
=
distribution_strategy
,
comm
=
comm
,
...
...
@@ -263,7 +265,7 @@ class distributed_data_object(Loggable, Versionable, object):
# repair its class
new_copy
.
__class__
=
self
.
__class__
# now copy everthing in the __dict__ except for the data array
for
key
,
value
in
self
.
__dict__
.
items
():
for
key
,
value
in
list
(
self
.
__dict__
.
items
()
)
:
if
key
!=
'data'
:
new_copy
.
__dict__
[
key
]
=
value
else
:
...
...
@@ -837,7 +839,7 @@ class distributed_data_object(Loggable, Versionable, object):
_builtin_helper
"""
return
self
.
_
_
div__
(
other
)
return
self
.
_
builtin_helper
(
'__true
div__
'
,
other
)
def
__rdiv__
(
self
,
other
):
""" x.__rdiv__(y) <==> y/x
...
...
@@ -857,7 +859,7 @@ class distributed_data_object(Loggable, Versionable, object):
_builtin_helper
"""
return
self
.
_
_r
div__
(
other
)
return
self
.
_
builtin_helper
(
'__rtrue
div__
'
,
other
)
def
__idiv__
(
self
,
other
):
""" x.__idiv__(y) <==> x/=y
...
...
@@ -879,7 +881,9 @@ class distributed_data_object(Loggable, Versionable, object):
_builtin_helper
"""
return
self
.
__idiv__
(
other
)
return
self
.
_builtin_helper
(
'__itruediv__'
,
other
,
inplace
=
True
)
def
__floordiv__
(
self
,
other
):
""" x.__floordiv__(y) <==> x//y
...
...
@@ -1472,7 +1476,10 @@ class distributed_data_object(Loggable, Versionable, object):
if
axis
is
():
return
self
.
copy
()
length
=
max
(
self
.
amax
()
+
1
,
minlength
)
if
minlength
is
not
None
:
length
=
max
(
self
.
amax
()
+
1
,
minlength
)
else
:
length
=
self
.
amax
()
+
1
return
self
.
distributor
.
bincount
(
obj
=
self
,
length
=
length
,
...
...
d2o/distributor_factory.py
View file @
2126ef50
from
__future__
import
absolute_import
# D2O
# Copyright (C) 2016 Theo Steininger
#
...
...
@@ -17,6 +18,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
next
from
builtins
import
str
from
builtins
import
map
from
builtins
import
range
from
builtins
import
object
import
numpy
as
np
from
keepers
import
Loggable
...
...
@@ -24,17 +30,18 @@ from keepers import Loggable
from
d2o.config
import
configuration
as
gc
,
\
dependency_injector
as
gdi
from
distributed_data_object
import
distributed_data_object
from
.
distributed_data_object
import
distributed_data_object
from
d2o_iter
import
d2o_slicing_iter
,
\
from
.
d2o_iter
import
d2o_slicing_iter
,
\
d2o_not_iter
from
d2o_librarian
import
d2o_librarian
from
dtype_converter
import
dtype_converter
from
cast_axis_to_tuple
import
cast_axis_to_tuple
from
translate_to_mpi_operator
import
op_translate_dict
from
slicing_generator
import
slicing_generator
from
.
d2o_librarian
import
d2o_librarian
from
.
dtype_converter
import
dtype_converter
from
.
cast_axis_to_tuple
import
cast_axis_to_tuple
from
.
translate_to_mpi_operator
import
op_translate_dict
from
.
slicing_generator
import
slicing_generator
from
strategies
import
STRATEGIES
from
.strategies
import
STRATEGIES
from
functools
import
reduce
MPI
=
gdi
[
gc
[
'mpi_module'
]]
h5py
=
gdi
.
get
(
'h5py'
)
...
...
@@ -208,10 +215,10 @@ class _distributor_factory(Loggable, object):
kwargs
[
'dtype'
]
=
self
.
dictionize_np
(
kwargs
[
'dtype'
])
kwargs
[
'distribution_strategy'
]
=
distribution_strategy
return
frozenset
(
kwargs
.
items
())
return
frozenset
(
list
(
kwargs
.
items
())
)
def
dictionize_np
(
self
,
x
):
dic
=
x
.
type
.
__dict__
.
items
()
dic
=
list
(
x
.
type
.
__dict__
.
items
()
)
if
x
is
np
.
float
:
dic
[
24
]
=
0
dic
[
29
]
=
0
...
...
@@ -275,8 +282,8 @@ def _infer_key_type(key):
elif
isinstance
(
key
,
tuple
)
or
isinstance
(
key
,
list
):
# Check if there is something different in the array than
# scalars and slices
scalarQ
=
np
.
array
(
map
(
np
.
isscalar
,
key
))
sliceQ
=
np
.
array
(
map
(
lambda
z
:
isinstance
(
z
,
slice
)
,
key
)
)
scalarQ
=
np
.
array
(
list
(
map
(
np
.
isscalar
,
key
))
)
sliceQ
=
np
.
array
(
[
isinstance
(
z
,
slice
)
for
z
in
key
]
)
if
np
.
all
(
scalarQ
+
sliceQ
):
found
=
'slicetuple'
else
:
...
...
@@ -377,26 +384,24 @@ class distributor(object):
# from the librarian
else
:
to_index_list
=
comm
.
allgather
(
to_key
.
index
)
to_key_list
=
map
(
lambda
z
:
d2o_librarian
[
z
]
,
to_index_list
)
to_key_list
=
[
d2o_librarian
[
z
]
for
z
in
to_index_list
]
# gather the local from_keys. It is the same procedure as above
if
from_found
!=
'd2o'
:
from_key_list
=
comm
.
allgather
(
from_key
)
else
:
from_index_list
=
comm
.
allgather
(
from_key
.
index
)
from_key_list
=
map
(
lambda
z
:
d2o_librarian
[
z
],
from_index_list
)
from_key_list
=
[
d2o_librarian
[
z
]
for
z
in
from_index_list
]
local_data_update_is_scalar
=
np
.
isscalar
(
data_update
)
local_scalar_list
=
comm
.
allgather
(
local_data_update_is_scalar
)
for
i
in
x
range
(
len
(
to_key_list
)):
for
i
in
range
(
len
(
to_key_list
)):
if
np
.
all
(
np
.
array
(
local_scalar_list
)
==
True
):
scalar_list
=
comm
.
allgather
(
data_update
)
temp_data_update
=
scalar_list
[
i
]
elif
isinstance
(
data_update
,
distributed_data_object
):
data_update_index_list
=
comm
.
allgather
(
data_update
.
index
)
data_update_list
=
map
(
lambda
z
:
d2o_librarian
[
z
],
data_update_index_list
)
data_update_list
=
[
d2o_librarian
[
z
]
for
z
in
data_update_index_list
]
temp_data_update
=
data_update_list
[
i
]
else
:
# build a temporary freeform d2o which only contains data
...
...
@@ -458,14 +463,14 @@ class distributor(object):
# do the reordering
ndim
=
len
(
self
.
global_shape
)
axis
=
sorted
(
cast_axis_to_tuple
(
axis
,
length
=
ndim
))
reordering
=
[
x
for
x
in
x
range
(
ndim
)
if
x
not
in
axis
]
reordering
=
[
x
for
x
in
range
(
ndim
)
if
x
not
in
axis
]
reordering
+=
axis
data
=
np
.
transpose
(
data
,
reordering
)
if
local_weights
is
not
None
:
local_weights
=
np
.
transpose
(
local_weights
,
reordering
)
reord_axis
=
range
(
ndim
-
len
(
axis
),
ndim
)
reord_axis
=
list
(
range
(
ndim
-
len
(
axis
),
ndim
)
)
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
...
...
@@ -507,9 +512,9 @@ class distributor(object):
# axis has been sorted above
insert_position
=
axis
[
0
]
new_ndim
=
len
(
local_counts
.
shape
)
return_order
=
(
range
(
0
,
insert_position
)
+
return_order
=
(
list
(
range
(
0
,
insert_position
)
)
+
[
new_ndim
-
1
,
]
+
range
(
insert_position
,
new_ndim
-
1
))
list
(
range
(
insert_position
,
new_ndim
-
1
))
)
local_counts
=
np
.
ascontiguousarray
(
local_counts
.
transpose
(
return_order
))
return
self
.
_combine_local_bincount_counts
(
obj
,
local_counts
,
axis
)
...
...
@@ -656,7 +661,7 @@ class _slicing_distributor(distributor):
# get first node with non-None data
try
:
start
=
next
(
i
for
i
in
x
range
(
size
)
if
got_array_list
[
i
]
>
1
)
start
=
next
(
i
for
i
in
range
(
size
)
if
got_array_list
[
i
]
>
1
)
except
(
StopIteration
):
raise
ValueError
(
"ERROR: No process with non-None data."
)
...
...
@@ -675,7 +680,7 @@ class _slicing_distributor(distributor):
self
.
comm
.
Bcast
([
result_data
,
mpi_dtype
],
root
=
start
)
for
i
in
x
range
(
start
+
1
,
size
):
for
i
in
range
(
start
+
1
,
size
):
if
got_array_list
[
i
]
>
1
:
if
rank
==
i
:
temp_data
=
data
...
...
@@ -686,7 +691,7 @@ class _slicing_distributor(distributor):
else
:
result_data
=
self
.
comm
.
bcast
(
data
,
root
=
start
)
for
i
in
x
range
(
start
+
1
,
size
):
for
i
in
range
(
start
+
1
,
size
):
if
got_array_list
[
i
]
>
1
:
temp_data
=
self
.
comm
.
bcast
(
data
,
root
=
i
)
result_data
=
op
(
result_data
,
temp_data
)
...
...
@@ -702,7 +707,7 @@ class _slicing_distributor(distributor):
if
axis
is
None
:
new_shape
=
()
else
:
new_shape
=
tuple
([
old_shape
[
i
]
for
i
in
x
range
(
len
(
old_shape
))
new_shape
=
tuple
([
old_shape
[
i
]
for
i
in
range
(
len
(
old_shape
))
if
i
not
in
axis
])
local_data
=
parent
.
data
...
...
@@ -1015,7 +1020,7 @@ class _slicing_distributor(distributor):
# from the librarian
else
:
index_list
=
comm
.
allgather
(
key
.
index
)
key_list
=
map
(
lambda
z
:
d2o_librarian
[
z
]
,
index_list
)
key_list
=
[
d2o_librarian
[
z
]
for
z
in
index_list
]
i
=
0
for
temp_key
in
key_list
:
# build the locally fed d2o
...
...
@@ -1104,7 +1109,7 @@ class _slicing_distributor(distributor):
# if the index lies within the local nodes' data-range
# take the shifted index, combined with rest of from_list_key
result
=
[
local_zeroth_key
]
for
ii
in
x
range
(
1
,
len
(
from_list_key
)):
for
ii
in
range
(
1
,
len
(
from_list_key
)):
current
=
from_list_key
[
ii
]
if
isinstance
(
current
,
distributed_data_object
):
result
.
append
(
current
.
get_full_data
())
...
...
@@ -1125,21 +1130,20 @@ class _slicing_distributor(distributor):
raise
ValueError
(
"Index out of bounds!"
)
# shift the indices according to shift
shift_list
=
self
.
comm
.
allgather
(
shift
)
local_zeroth_key_list
=
map
(
lambda
z
:
zeroth_key
-
z
,
shift_list
)
local_zeroth_key_list
=
[
zeroth_key
-
z
for
z
in
shift_list
]
# discard all entries where the indices are negative or larger
# than local_length
greater_than_lower_list
=
map
(
lambda
z
:
z
>=
0
,
local_zeroth_key_list
)
greater_than_lower_list
=
[
z
>=
0
for
z
in
local_zeroth_key_list
]
# -> build up a list with the local selection d2o's
local_length_list
=
self
.
comm
.
allgather
(
local_length
)
less_than_upper_list
=
map
(
lambda
z
,
zz
:
z
<
zz
,
local_zeroth_key_list
,
local_length_list
)
local_selection_list
=
map
(
lambda
z
,
zz
:
z
*
zz
,
less_than_upper_list
,
greater_than_lower_list
)
for
j
in
x
range
(
len
(
local_zeroth_key_list
)):
less_than_upper_list
=
list
(
map
(
lambda
z
,
zz
:
z
<
zz
,
local_zeroth_key_list
,
local_length_list
)
)
local_selection_list
=
list
(
map
(
lambda
z
,
zz
:
z
*
zz
,
less_than_upper_list
,
greater_than_lower_list
)
)
for
j
in
range
(
len
(
local_zeroth_key_list
)):
temp_result
=
local_zeroth_key_list
[
j
].
\
get_data
(
local_selection_list
[
j
]).
\
get_full_data
(
target_rank
=
j
)
...
...
@@ -1152,7 +1156,7 @@ class _slicing_distributor(distributor):
# "ERROR: The first dimemnsion of list_key must be sorted!"))
result
=
[
result
]
for
ii
in
x
range
(
1
,
len
(
from_list_key
)):
for
ii
in
range
(
1
,
len
(
from_list_key
)):
current
=
from_list_key
[
ii
]
if
np
.
isscalar
(
current
):
result
.
append
(
current
)
...
...
@@ -1161,7 +1165,7 @@ class _slicing_distributor(distributor):
local_selection_list
[
rank
],
local_keys
=
True
).
get_local_data
(
copy
=
False
))
else
:
for
j
in
x
range
(
len
(
local_selection_list
)):
for
j
in
range
(
len
(
local_selection_list
)):
temp_select
=
local_selection_list
[
j
].
\
get_full_data
(
target_rank
=
j
)
if
j
==
rank
:
...
...
@@ -1192,7 +1196,7 @@ class _slicing_distributor(distributor):
# raise ValueError(about_cstring(
# "ERROR: The first dimemnsion of list_key must be sorted!"))
for
ii
in
x
range
(
1
,
len
(
from_list_key
)):
for
ii
in
range
(
1
,
len
(
from_list_key
)):
current
=
from_list_key
[
ii
]
if
np
.
isscalar
(
current
):
result
.
append
(
current
)
...
...
@@ -1270,21 +1274,23 @@ class _slicing_distributor(distributor):
global_length
):
# Reformulate negative indices
if
slice_object
.
start
<
0
and
slice_object
.
start
is
not
None
:
temp_start
=
slice_object
.
start
+
global_length
if
temp_start
<
0
:
temp_start
=
0
if
slice_object
.
start
is
not
None
:
if
slice_object
.
start
<
0
:
temp_start
=
slice_object
.
start
+
global_length
if
temp_start
<
0
:
temp_start
=
0
slice_object
=
slice
(
temp_start
,
slice_object
.
stop
,
slice_object
.
step
)
slice_object
=
slice
(
temp_start
,
slice_object
.
stop
,
slice_object
.
step
)
if
slice_object
.
stop
<
0
and
slice_object
.
stop
is
not
None
:
temp_stop
=
slice_object
.
stop
+
global_length
if
temp_stop
<
0
:
temp_stop
=
None
if
slice_object
.
stop
is
not
None
:
if
slice_object
.
stop
<
0
:
temp_stop
=
slice_object
.
stop
+
global_length
if
temp_stop
<
0
:
temp_stop
=
None
slice_object
=
slice
(
slice_object
.
start
,
temp_stop
,
slice_object
.
step
)
slice_object
=
slice
(
slice_object
.
start
,
temp_stop
,
slice_object
.
step
)
# initialize the step
if
slice_object
.
step
is
None
:
...
...
@@ -1530,10 +1536,10 @@ class _slicing_distributor(distributor):
local_where
[
0
]
=
local_where
[
0
]
+
self
.
local_start
local_where
=
tuple
(
local_where
)
global_where
=
map
(
lambda
z
:
distributed_data_object
(
local_data
=
z
,
distribution_strategy
=
'freeform'
)
,
local_where
)
global_where
=
\
[
distributed_data_object
(
local_data
=
z
,
distribution_strategy
=
'freeform'
)
for
z
in
local_where
]
return
global_where
def
unique
(
self
,
data
):
...
...
@@ -1552,7 +1558,7 @@ class _slicing_distributor(distributor):
[
local_data_length_list
,
MPI
.
INT
])
global_unique_data
=
np
.
array
([],
dtype
=
self
.
dtype
)
for
i
in
x
range
(
size
):
for
i
in
range
(
size
):
# broadcast data to the other nodes
# prepare the recv array
if
rank
!=
i
:
...
...
@@ -2151,10 +2157,10 @@ class _not_distributor(distributor):
def
where
(
self
,
data
):
# compute the result from np.where
local_where
=
np
.
where
(
data
)
global_where
=
map
(
lambda
z
:
distributed_data_object
(
global_data
=
z
,
distribution_strategy
=
'not'
)
,
local_where
)
global_where
=
\
[
distributed_data_object
(
global_data
=
z
,
distribution_strategy
=
'not'
)
for
z
in
local_where
]
return
global_where
def
unique
(
self
,
data
):
...
...
d2o/dtype_converter.py
View file @
2126ef50
...
...
@@ -16,6 +16,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
map
from
builtins
import
object
import
numpy
as
np
from
d2o.config
import
configuration
as
gc
,
\
...
...
@@ -54,15 +56,17 @@ class _dtype_converter(object):
[
np
.
dtype
(
'complex128'
),
MPI
.
DOUBLE_COMPLEX
]]
to_mpi_pre_dict
=
np
.
array
(
pre_dict
)
to_mpi_pre_dict
[:,
0
]
=
map
(
self
.
dictionize_np
,
to_mpi_pre_dict
[:,
0
])
to_mpi_pre_dict
[:,
0
]
=
list
(
map
(
self
.
dictionize_np
,
to_mpi_pre_dict
[:,
0
]))
self
.
_to_mpi_dict
=
dict
(
to_mpi_pre_dict
)
to_np_pre_dict
=
np
.
array
(
pre_dict
)[:,
::
-
1
]
to_np_pre_dict
[:,
0
]
=
map
(
self
.
dictionize_mpi
,
to_np_pre_dict
[:,
0
])
to_np_pre_dict
[:,
0
]
=
list
(
map
(
self
.
dictionize_mpi
,
to_np_pre_dict
[:,
0
]))
self
.
_to_np_dict
=
dict
(
to_np_pre_dict
)
def
dictionize_np
(
self
,
x
):
dic
=
x
.
type
.
__dict__
.
items
()
dic
=
list
(
x
.
type
.
__dict__
.
items
()
)
if
x
.
type
is
np
.
float
:
dic
[
24
]
=
0
dic
[
29
]
=
0
...
...
d2o/factory_methods.py
View file @
2126ef50
# -*- coding: utf-8 -*-
from
__future__
import
absolute_import
import
numpy
as
np
from
d2o.config
import
configuration
as
gc
from
distributed_data_object
import
distributed_data_object
from
.
distributed_data_object
import
distributed_data_object
from
strategies
import
STRATEGIES
from
.
strategies
import
STRATEGIES
__all__
=
[
'arange'
]
...
...
d2o/slicing_generator.py
View file @
2126ef50
...
...
@@ -16,6 +16,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
next
from
builtins
import
range
import
itertools
...
...
@@ -52,7 +54,7 @@ def slicing_generator(shape, axes):
raise
ValueError
(
"ERROR: axes(axis) does not match shape."
)
axes_select
=
[
0
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
shape
)]
axes_iterables
=
\
[
range
(
y
)
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
[
list
(
range
(
y
)
)
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
for
current_index
in
itertools
.
product
(
*
axes_iterables
):
it_iter
=
iter
(
current_index
)
slice_list
=
[
next
(
it_iter
)
if
use_axis
else
...
...
test/test_distributed_data_object.py
View file @
2126ef50
from
__future__
import
division
from
__future__
import
print_function
# D2O
# Copyright (C) 2016 Theo Steininger
#
...
...
@@ -16,6 +18,11 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from
builtins
import
str
from
builtins
import
range
from
future.standard_library
import
PY2
,
PY3
from
numpy.testing
import
assert_equal
,
\
assert_almost_equal
,
\
assert_raises
,
\
...
...
@@ -86,12 +93,19 @@ hdf5_distribution_strategies = STRATEGIES['hdf5']
###############################################################################
binary_non_inplace_operators
=
[
'__add__'
,
'__radd__'
,
'__sub__'
,
'__rsub__'
,
'__div__'
,
'__truediv__'
,
'__rdiv__'
,
'__rtruediv__'
,
'__floordiv__'
,
'__truediv__'
,
'__rtruediv__'
,
'__floordiv__'
,
'__rfloordiv__'
,
'__mul__'
,
'__rmul__'
,
'__pow__'
,
'__rpow__'
]
binary_inplace_operators
=
[
'__iadd__'
,
'__isub__'
,
'__idiv__'
,
'__itruediv__'
,
if
PY2
:
binary_non_inplace_operators
+=
[
'__div__'
,
'__rdiv__'
]
binary_inplace_operators
=
[
'__iadd__'
,
'__isub__'
,
'__itruediv__'
,
'__ifloordiv__'
,
'__imul__'
,
'__ipow__'
]
if
PY2
:
binary_inplace_operators
+=
[
'__idiv__'
]
comparison_operators
=
[
'__ne__'
,
'__lt__'
,
'__le__'
,
'__eq__'
,
'__ge__'
,
'__gt__'
,
]
...
...
@@ -262,13 +276,12 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(),
np
.
dtype
(
'float64'
)],
],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_special_init_cases
(
self
,
(
global_data
,
global_shape
,
dtype
,
expected_shape
,
expected_dtype
),
distribution_strategy
):
def
test_special_init_cases
(
self
,
para
,
distribution_strategy
):
(
global_data
,
global_shape
,
dtype
,
expected_shape
,
expected_dtype
)
=
para
obj
=
distributed_data_object
(
global_data
=
global_data
,
global_shape
=
global_shape
,
...
...
@@ -283,7 +296,8 @@ class Test_Globaltype_Initialization(unittest.TestCase):
@
parameterized
.
expand
(
itertools
.
product
(
hdf5_test_paths
,
hdf5_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_hdf5_init
(
self
,
(
alias
,
path
),
distribution_strategy
):
def
test_hdf5_init
(
self
,
para
,
distribution_strategy
):
(
alias
,
path
)
=
para
obj
=
distributed_data_object
(
global_data
=
1.
,
global_shape
=
(
12
,
6
),
...
...
@@ -303,12 +317,12 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(
None
,
None
,
np
.
int_
,
None
,
(
3
,))],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_failed_init_on_unsufficient_parameters
(
self
,
(
global_data
,
global_shape
,
dtype
,
local_data
,
local_shape
),
def
test_failed_init_on_unsufficient_parameters
(
self
,
para
,
distribution_strategy
):
(
global_data
,
global_shape
,
dtype
,
local_data
,
local_shape
)
=
para
assert_raises
(
ValueError
,
lambda
:
distributed_data_object
(
global_data
=
global_data
,
...
...
@@ -329,13 +343,13 @@ class Test_Globaltype_Initialization(unittest.TestCase):
None
,
None
),
],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_failed_init_unsufficient_params_mpi
(
self
,
(
global_data
,
global_shape
,
dtype
,
local_data
,
local_shape
),
def
test_failed_init_unsufficient_params_mpi
(
self
,
para
,
distribution_strategy
):
(
global_data
,
global_shape
,
dtype
,
local_data
,
local_shape
)
=
para
assert_raises
(
ValueError
,
lambda
:
distributed_data_object
(
global_data
=
global_data
,
...
...
@@ -481,13 +495,13 @@ class Test_Localtype_Initialization(unittest.TestCase):