Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
86674115
Commit
86674115
authored
Apr 06, 2016
by
theos
Browse files
Improved distributed_data_object's performance by avoiding unnecessary copies.
Added `copy` kwargs where missing.
parent
79d0153f
Pipeline
#1275
skipped
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
d2o/distributor_factory.py
View file @
86674115
...
...
@@ -392,6 +392,7 @@ class distributor(object):
temp_data_update
=
distributed_data_object
(
local_data
=
temp_data
,
distribution_strategy
=
'freeform'
,
copy
=
False
,
comm
=
self
.
comm
)
# disperse the data one after another
self
.
_disperse_data_primitive
(
...
...
@@ -409,7 +410,6 @@ class distributor(object):
class
_slicing_distributor
(
distributor
):
def
__init__
(
self
,
slicer
,
name
,
dtype
,
comm
,
**
remaining_parsed_kwargs
):
self
.
comm
=
comm
...
...
@@ -542,9 +542,10 @@ class _slicing_distributor(distributor):
if
isinstance
(
data
,
distributed_data_object
):
temp_d2o
=
data
.
get_data
((
slice
(
self
.
local_start
,
self
.
local_end
),),
local_keys
=
True
)
return
temp_d2o
.
get_local_data
().
astype
(
self
.
dtype
,
local_keys
=
True
,
copy
=
copy
)
return
temp_d2o
.
get_local_data
(
copy
=
False
).
astype
(
self
.
dtype
,
copy
=
False
)
else
:
return
data
[
self
.
local_start
:
self
.
local_end
].
astype
(
self
.
dtype
,
...
...
@@ -673,15 +674,18 @@ class _slicing_distributor(distributor):
l
=
data_length
if
isinstance
(
data_update
,
distributed_data_object
):
data
[
local_
to_key
]
=
data_update
.
get_data
(
local_
data_update
=
data_update
.
get_data
(
slice
(
o
[
rank
],
o
[
rank
]
+
l
),
local_keys
=
True
).
get_local_data
().
astype
(
self
.
dtype
)
).
get_local_data
(
copy
=
False
)
data
[
local_to_key
]
=
local_data_update
.
astype
(
self
.
dtype
,
copy
=
False
)
elif
np
.
isscalar
(
data_update
):
data
[
local_to_key
]
=
data_update
else
:
data
[
local_to_key
]
=
np
.
array
(
data_update
[
o
[
rank
]:
o
[
rank
]
+
l
],
copy
=
copy
).
astype
(
self
.
dtype
)
copy
=
copy
).
astype
(
self
.
dtype
,
copy
=
False
)
return
data
def
disperse_data_to_slices
(
self
,
data
,
to_slices
,
...
...
@@ -763,10 +767,12 @@ class _slicing_distributor(distributor):
update_slice
=
localized_from_slice
+
from_slices
[
1
:]
if
isinstance
(
data_update
,
distributed_data_object
):
local_data
_update
=
data_update
.
get_data
(
selected
_update
=
data_update
.
get_data
(
key
=
update_slice
,
local_keys
=
True
).
get_local_data
(
copy
=
copy
).
astype
(
self
.
dtype
)
local_keys
=
True
)
local_data_update
=
selected_update
.
get_local_data
(
copy
=
False
)
local_data_update
=
local_data_update
.
astype
(
self
.
dtype
,
copy
=
False
)
if
np
.
prod
(
np
.
shape
(
local_data_update
))
!=
0
:
data
[
local_to_slice
]
=
local_data_update
# elif np.isscalar(data_update):
...
...
@@ -776,9 +782,10 @@ class _slicing_distributor(distributor):
if
np
.
prod
(
np
.
shape
(
local_data_update
))
!=
0
:
data
[
local_to_slice
]
=
np
.
array
(
local_data_update
,
copy
=
copy
).
astype
(
self
.
dtype
)
copy
=
copy
).
astype
(
self
.
dtype
,
copy
=
False
)
def
collect_data
(
self
,
data
,
key
,
local_keys
=
False
,
**
kwargs
):
def
collect_data
(
self
,
data
,
key
,
local_keys
=
False
,
copy
=
True
,
**
kwargs
):
# collect_data supports three types of keys
# Case 1: key is a slicing/index tuple
# Case 2: key is a boolean-array of the same shape as self
...
...
@@ -793,7 +800,8 @@ class _slicing_distributor(distributor):
comm
=
self
.
comm
if
local_keys
is
False
:
return
self
.
_collect_data_primitive
(
data
,
key
,
found
,
found_boolean
,
**
kwargs
)
found_boolean
,
copy
=
copy
,
**
kwargs
)
else
:
# assert that all keys are from same type
found_list
=
comm
.
allgather
(
found
)
...
...
@@ -817,7 +825,7 @@ class _slicing_distributor(distributor):
# build the locally fed d2o
temp_d2o
=
self
.
_collect_data_primitive
(
data
,
temp_key
,
found
,
found_boolean
,
**
kwargs
)
copy
=
copy
,
**
kwargs
)
# collect the data stored in the d2o to the individual target
# rank
temp_data
=
temp_d2o
.
get_full_data
(
target_rank
=
i
)
...
...
@@ -827,17 +835,19 @@ class _slicing_distributor(distributor):
return_d2o
=
distributed_data_object
(
local_data
=
individual_data
,
distribution_strategy
=
'freeform'
,
copy
=
False
,
comm
=
self
.
comm
)
return
return_d2o
def
_collect_data_primitive
(
self
,
data
,
key
,
found
,
found_boolean
,
**
kwargs
):
copy
=
True
,
**
kwargs
):
# Case 1: key is a slice-tuple. Hence, the basic indexing/slicing
# machinery will be used
if
found
==
'slicetuple'
:
return
self
.
collect_data_from_slices
(
data
=
data
,
slice_objects
=
key
,
copy
=
copy
,
**
kwargs
)
# Case 2: key is an array
elif
(
found
==
'ndarray'
or
found
==
'd2o'
):
...
...
@@ -845,6 +855,7 @@ class _slicing_distributor(distributor):
if
found_boolean
:
return
self
.
collect_data_from_bool
(
data
=
data
,
boolean_key
=
key
,
copy
=
copy
,
**
kwargs
)
# Case 2.2: The array is not boolean. Only 1-dimensional
# advanced slicing is supported.
...
...
@@ -854,16 +865,18 @@ class _slicing_distributor(distributor):
"WARNING: Only one-dimensional advanced indexing "
+
"is supported"
))
# Make a recursive call in order to trigger the 'list'-section
return
self
.
collect_data
(
data
=
data
,
key
=
[
key
],
**
kwargs
)
return
self
.
collect_data
(
data
=
data
,
key
=
[
key
],
copy
=
copy
,
**
kwargs
)
# Case 3 : key is a list. This list is interpreted as one-dimensional
# advanced indexing list.
elif
found
==
'indexinglist'
:
return
self
.
collect_data_from_list
(
data
=
data
,
list_key
=
key
,
copy
=
copy
,
**
kwargs
)
def
collect_data_from_list
(
self
,
data
,
list_key
,
**
kwargs
):
def
collect_data_from_list
(
self
,
data
,
list_key
,
copy
=
True
,
**
kwargs
):
if
list_key
==
[]:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: key == [] is an unsupported key!"
))
...
...
@@ -872,6 +885,7 @@ class _slicing_distributor(distributor):
global_result
=
distributed_data_object
(
local_data
=
local_result
,
distribution_strategy
=
'freeform'
,
copy
=
copy
,
comm
=
self
.
comm
)
return
global_result
...
...
@@ -1000,12 +1014,13 @@ class _slicing_distributor(distributor):
return
result
def
collect_data_from_bool
(
self
,
data
,
boolean_key
,
**
kwargs
):
def
collect_data_from_bool
(
self
,
data
,
boolean_key
,
copy
=
True
,
**
kwargs
):
local_boolean_key
=
self
.
extract_local_data
(
boolean_key
)
local_result
=
data
[
local_boolean_key
]
global_result
=
distributed_data_object
(
local_data
=
local_result
,
distribution_strategy
=
'freeform'
,
copy
=
copy
,
comm
=
self
.
comm
)
return
global_result
...
...
@@ -1023,7 +1038,7 @@ class _slicing_distributor(distributor):
comm
.
barrier
()
return
new_data
def
collect_data_from_slices
(
self
,
data
,
slice_objects
,
def
collect_data_from_slices
(
self
,
data
,
slice_objects
,
copy
=
True
,
target_rank
=
'all'
):
(
slice_objects
,
sliceified
)
=
self
.
_sliceify
(
slice_objects
)
...
...
@@ -1046,6 +1061,7 @@ class _slicing_distributor(distributor):
global_result
=
distributed_data_object
(
local_data
=
local_result
,
distribution_strategy
=
'freeform'
,
copy
=
copy
,
comm
=
self
.
comm
)
return
self
.
_defold
(
global_result
,
sliceified
)
...
...
@@ -1357,7 +1373,7 @@ class _slicing_distributor(distributor):
# low level mess!!
if
isinstance
(
in_data
,
distributed_data_object
):
local_data
=
in_data
.
data
local_data
=
in_data
.
get_local_data
(
copy
=
False
)
elif
isinstance
(
in_data
,
np
.
ndarray
)
==
False
:
local_data
=
np
.
array
(
in_data
,
copy
=
False
)
in_data
=
local_data
...
...
@@ -1397,17 +1413,17 @@ class _slicing_distributor(distributor):
# and broadcast the shape to the others
if
sliceified
[
0
]:
# Case 1: The in_data d2o has more than one dimension
if
len
(
in_data
.
shape
)
>
1
:
local_has_data
=
(
np
.
prod
(
np
.
shape
(
in_data
.
get_
local_data
())
)
!=
0
)
if
len
(
in_data
.
shape
)
>
1
and
\
in_data
.
distribution_strategy
in
STRATEGIES
[
'slicing'
]:
local_in_data
=
in_data
.
get_local_data
(
copy
=
False
)
local_has_data
=
(
np
.
prod
(
local_
in_
data
.
shape
)
!=
0
)
local_has_data_list
=
np
.
array
(
self
.
comm
.
allgather
(
local_has_data
))
nodes_with_data
=
np
.
where
(
local_has_data_list
)[
0
]
if
np
.
shape
(
nodes_with_data
)[
0
]
>
1
:
raise
ValueError
(
"ERROR: scalar index on first dimension,
but
"
+
" more than one node has data!"
)
"ERROR: scalar index on first dimension, "
+
"
but
more than one node has data!"
)
elif
np
.
shape
(
nodes_with_data
)[
0
]
==
1
:
node_with_data
=
nodes_with_data
[
0
]
else
:
...
...
@@ -1423,6 +1439,7 @@ class _slicing_distributor(distributor):
temp_local_shape
=
np
.
array
(
broadcasted_shape
)
temp_local_shape
[
0
]
=
0
temp_local_shape
=
tuple
(
temp_local_shape
)
# Case 2: The in_data d2o is only onedimensional
else
:
# The data contained in the d2o must be stored on one
...
...
@@ -1439,6 +1456,7 @@ class _slicing_distributor(distributor):
new_data
=
distributed_data_object
(
local_data
=
reshaped_data
,
distribution_strategy
=
in_data
.
distribution_strategy
,
copy
=
False
,
comm
=
self
.
comm
)
return
new_data
else
:
...
...
@@ -1499,6 +1517,7 @@ class _slicing_distributor(distributor):
new_data
=
distributed_data_object
(
local_data
=
reshaped_data
,
distribution_strategy
=
'freeform'
,
copy
=
False
,
comm
=
self
.
comm
)
return
new_data
else
:
...
...
@@ -1714,18 +1733,32 @@ class _not_distributor(distributor):
if
np
.
isscalar
(
data_update
):
update
=
data_update
elif
isinstance
(
data_update
,
distributed_data_object
):
update
=
data_update
[
from_key
].
get_full_data
().
\
astype
(
self
.
dtype
)
update
=
data_update
[
from_key
].
get_full_data
()
update
=
update
.
astype
(
self
.
dtype
,
copy
=
False
)
else
:
if
isinstance
(
from_key
,
distributed_data_object
):
from_key
=
from_key
.
get_full_data
()
elif
isinstance
(
from_key
,
list
):
try
:
from_key
=
[
item
.
get_full_data
()
for
item
in
from_key
]
except
(
AttributeError
):
pass
update
=
np
.
array
(
data_update
,
copy
=
copy
)[
from_key
].
astype
(
self
.
dtype
)
copy
=
copy
)[
from_key
]
update
=
update
.
astype
(
self
.
dtype
,
copy
=
False
)
data
[
to_key
]
=
update
def
collect_data
(
self
,
data
,
key
,
local_keys
=
False
,
**
kwargs
):
if
isinstance
(
key
,
distributed_data_object
):
key
=
key
.
get_full_data
()
elif
isinstance
(
key
,
list
):
try
:
key
=
[
item
.
get_full_data
()
for
item
in
key
]
except
(
AttributeError
):
pass
new_data
=
data
[
key
]
if
isinstance
(
new_data
,
np
.
ndarray
):
if
local_keys
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment