Commit 86674115 authored by theos
Browse files

Improved distributed_data_object's performance by avoiding unnecessary copies.

Added `copy` kwargs where missing.
parent 79d0153f
Pipeline #1275 skipped
...@@ -390,9 +390,10 @@ class distributor(object): ...@@ -390,9 +390,10 @@ class distributor(object):
else: else:
temp_data = data_update temp_data = data_update
temp_data_update = distributed_data_object( temp_data_update = distributed_data_object(
local_data=temp_data, local_data=temp_data,
distribution_strategy='freeform', distribution_strategy='freeform',
comm=self.comm) copy=False,
comm=self.comm)
# disperse the data one after another # disperse the data one after another
self._disperse_data_primitive( self._disperse_data_primitive(
data=data, data=data,
...@@ -409,7 +410,6 @@ class distributor(object): ...@@ -409,7 +410,6 @@ class distributor(object):
class _slicing_distributor(distributor): class _slicing_distributor(distributor):
def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs): def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
self.comm = comm self.comm = comm
...@@ -542,9 +542,10 @@ class _slicing_distributor(distributor): ...@@ -542,9 +542,10 @@ class _slicing_distributor(distributor):
if isinstance(data, distributed_data_object): if isinstance(data, distributed_data_object):
temp_d2o = data.get_data((slice(self.local_start, temp_d2o = data.get_data((slice(self.local_start,
self.local_end),), self.local_end),),
local_keys=True) local_keys=True,
return temp_d2o.get_local_data().astype(self.dtype, copy=copy)
copy=copy) return temp_d2o.get_local_data(copy=False).astype(self.dtype,
copy=False)
else: else:
return data[self.local_start:self.local_end].astype( return data[self.local_start:self.local_end].astype(
self.dtype, self.dtype,
...@@ -673,15 +674,18 @@ class _slicing_distributor(distributor): ...@@ -673,15 +674,18 @@ class _slicing_distributor(distributor):
l = data_length l = data_length
if isinstance(data_update, distributed_data_object): if isinstance(data_update, distributed_data_object):
data[local_to_key] = data_update.get_data( local_data_update = data_update.get_data(
slice(o[rank], o[rank] + l), slice(o[rank], o[rank] + l),
local_keys=True local_keys=True
).get_local_data().astype(self.dtype) ).get_local_data(copy=False)
data[local_to_key] = local_data_update.astype(self.dtype,
copy=False)
elif np.isscalar(data_update): elif np.isscalar(data_update):
data[local_to_key] = data_update data[local_to_key] = data_update
else: else:
data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l], data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l],
copy=copy).astype(self.dtype) copy=copy).astype(self.dtype,
copy=False)
return data return data
def disperse_data_to_slices(self, data, to_slices, def disperse_data_to_slices(self, data, to_slices,
...@@ -763,10 +767,12 @@ class _slicing_distributor(distributor): ...@@ -763,10 +767,12 @@ class _slicing_distributor(distributor):
update_slice = localized_from_slice + from_slices[1:] update_slice = localized_from_slice + from_slices[1:]
if isinstance(data_update, distributed_data_object): if isinstance(data_update, distributed_data_object):
local_data_update = data_update.get_data( selected_update = data_update.get_data(
key=update_slice, key=update_slice,
local_keys=True local_keys=True)
).get_local_data(copy=copy).astype(self.dtype) local_data_update = selected_update.get_local_data(copy=False)
local_data_update = local_data_update.astype(self.dtype,
copy=False)
if np.prod(np.shape(local_data_update)) != 0: if np.prod(np.shape(local_data_update)) != 0:
data[local_to_slice] = local_data_update data[local_to_slice] = local_data_update
# elif np.isscalar(data_update): # elif np.isscalar(data_update):
...@@ -776,9 +782,10 @@ class _slicing_distributor(distributor): ...@@ -776,9 +782,10 @@ class _slicing_distributor(distributor):
if np.prod(np.shape(local_data_update)) != 0: if np.prod(np.shape(local_data_update)) != 0:
data[local_to_slice] = np.array( data[local_to_slice] = np.array(
local_data_update, local_data_update,
copy=copy).astype(self.dtype) copy=copy).astype(self.dtype,
copy=False)
def collect_data(self, data, key, local_keys=False, **kwargs): def collect_data(self, data, key, local_keys=False, copy=True, **kwargs):
# collect_data supports three types of keys # collect_data supports three types of keys
# Case 1: key is a slicing/index tuple # Case 1: key is a slicing/index tuple
# Case 2: key is a boolean-array of the same shape as self # Case 2: key is a boolean-array of the same shape as self
...@@ -793,7 +800,8 @@ class _slicing_distributor(distributor): ...@@ -793,7 +800,8 @@ class _slicing_distributor(distributor):
comm = self.comm comm = self.comm
if local_keys is False: if local_keys is False:
return self._collect_data_primitive(data, key, found, return self._collect_data_primitive(data, key, found,
found_boolean, **kwargs) found_boolean, copy=copy,
**kwargs)
else: else:
# assert that all keys are from same type # assert that all keys are from same type
found_list = comm.allgather(found) found_list = comm.allgather(found)
...@@ -817,7 +825,7 @@ class _slicing_distributor(distributor): ...@@ -817,7 +825,7 @@ class _slicing_distributor(distributor):
# build the locally fed d2o # build the locally fed d2o
temp_d2o = self._collect_data_primitive(data, temp_key, found, temp_d2o = self._collect_data_primitive(data, temp_key, found,
found_boolean, found_boolean,
**kwargs) copy=copy, **kwargs)
# collect the data stored in the d2o to the individual target # collect the data stored in the d2o to the individual target
# rank # rank
temp_data = temp_d2o.get_full_data(target_rank=i) temp_data = temp_d2o.get_full_data(target_rank=i)
...@@ -825,19 +833,21 @@ class _slicing_distributor(distributor): ...@@ -825,19 +833,21 @@ class _slicing_distributor(distributor):
individual_data = temp_data individual_data = temp_data
i += 1 i += 1
return_d2o = distributed_data_object( return_d2o = distributed_data_object(
local_data=individual_data, local_data=individual_data,
distribution_strategy='freeform', distribution_strategy='freeform',
comm=self.comm) copy=False,
comm=self.comm)
return return_d2o return return_d2o
def _collect_data_primitive(self, data, key, found, found_boolean, def _collect_data_primitive(self, data, key, found, found_boolean,
**kwargs): copy=True, **kwargs):
# Case 1: key is a slice-tuple. Hence, the basic indexing/slicing # Case 1: key is a slice-tuple. Hence, the basic indexing/slicing
# machinery will be used # machinery will be used
if found == 'slicetuple': if found == 'slicetuple':
return self.collect_data_from_slices(data=data, return self.collect_data_from_slices(data=data,
slice_objects=key, slice_objects=key,
copy=copy,
**kwargs) **kwargs)
# Case 2: key is an array # Case 2: key is an array
elif (found == 'ndarray' or found == 'd2o'): elif (found == 'ndarray' or found == 'd2o'):
...@@ -845,6 +855,7 @@ class _slicing_distributor(distributor): ...@@ -845,6 +855,7 @@ class _slicing_distributor(distributor):
if found_boolean: if found_boolean:
return self.collect_data_from_bool(data=data, return self.collect_data_from_bool(data=data,
boolean_key=key, boolean_key=key,
copy=copy,
**kwargs) **kwargs)
# Case 2.2: The array is not boolean. Only 1-dimensional # Case 2.2: The array is not boolean. Only 1-dimensional
# advanced slicing is supported. # advanced slicing is supported.
...@@ -854,16 +865,18 @@ class _slicing_distributor(distributor): ...@@ -854,16 +865,18 @@ class _slicing_distributor(distributor):
"WARNING: Only one-dimensional advanced indexing " + "WARNING: Only one-dimensional advanced indexing " +
"is supported")) "is supported"))
# Make a recursive call in order to trigger the 'list'-section # Make a recursive call in order to trigger the 'list'-section
return self.collect_data(data=data, key=[key], **kwargs) return self.collect_data(data=data, key=[key], copy=copy,
**kwargs)
# Case 3 : key is a list. This list is interpreted as one-dimensional # Case 3 : key is a list. This list is interpreted as one-dimensional
# advanced indexing list. # advanced indexing list.
elif found == 'indexinglist': elif found == 'indexinglist':
return self.collect_data_from_list(data=data, return self.collect_data_from_list(data=data,
list_key=key, list_key=key,
copy=copy,
**kwargs) **kwargs)
def collect_data_from_list(self, data, list_key, **kwargs): def collect_data_from_list(self, data, list_key, copy=True, **kwargs):
if list_key == []: if list_key == []:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: key == [] is an unsupported key!")) "ERROR: key == [] is an unsupported key!"))
...@@ -872,6 +885,7 @@ class _slicing_distributor(distributor): ...@@ -872,6 +885,7 @@ class _slicing_distributor(distributor):
global_result = distributed_data_object( global_result = distributed_data_object(
local_data=local_result, local_data=local_result,
distribution_strategy='freeform', distribution_strategy='freeform',
copy=copy,
comm=self.comm) comm=self.comm)
return global_result return global_result
...@@ -1000,12 +1014,13 @@ class _slicing_distributor(distributor): ...@@ -1000,12 +1014,13 @@ class _slicing_distributor(distributor):
return result return result
def collect_data_from_bool(self, data, boolean_key, **kwargs): def collect_data_from_bool(self, data, boolean_key, copy=True, **kwargs):
local_boolean_key = self.extract_local_data(boolean_key) local_boolean_key = self.extract_local_data(boolean_key)
local_result = data[local_boolean_key] local_result = data[local_boolean_key]
global_result = distributed_data_object( global_result = distributed_data_object(
local_data=local_result, local_data=local_result,
distribution_strategy='freeform', distribution_strategy='freeform',
copy=copy,
comm=self.comm) comm=self.comm)
return global_result return global_result
...@@ -1023,7 +1038,7 @@ class _slicing_distributor(distributor): ...@@ -1023,7 +1038,7 @@ class _slicing_distributor(distributor):
comm.barrier() comm.barrier()
return new_data return new_data
def collect_data_from_slices(self, data, slice_objects, def collect_data_from_slices(self, data, slice_objects, copy=True,
target_rank='all'): target_rank='all'):
(slice_objects, sliceified) = self._sliceify(slice_objects) (slice_objects, sliceified) = self._sliceify(slice_objects)
...@@ -1046,6 +1061,7 @@ class _slicing_distributor(distributor): ...@@ -1046,6 +1061,7 @@ class _slicing_distributor(distributor):
global_result = distributed_data_object( global_result = distributed_data_object(
local_data=local_result, local_data=local_result,
distribution_strategy='freeform', distribution_strategy='freeform',
copy=copy,
comm=self.comm) comm=self.comm)
return self._defold(global_result, sliceified) return self._defold(global_result, sliceified)
...@@ -1357,7 +1373,7 @@ class _slicing_distributor(distributor): ...@@ -1357,7 +1373,7 @@ class _slicing_distributor(distributor):
# low level mess!! # low level mess!!
if isinstance(in_data, distributed_data_object): if isinstance(in_data, distributed_data_object):
local_data = in_data.data local_data = in_data.get_local_data(copy=False)
elif isinstance(in_data, np.ndarray) == False: elif isinstance(in_data, np.ndarray) == False:
local_data = np.array(in_data, copy=False) local_data = np.array(in_data, copy=False)
in_data = local_data in_data = local_data
...@@ -1397,17 +1413,17 @@ class _slicing_distributor(distributor): ...@@ -1397,17 +1413,17 @@ class _slicing_distributor(distributor):
# and broadcast the shape to the others # and broadcast the shape to the others
if sliceified[0]: if sliceified[0]:
# Case 1: The in_data d2o has more than one dimension # Case 1: The in_data d2o has more than one dimension
if len(in_data.shape) > 1: if len(in_data.shape) > 1 and \
local_has_data = (np.prod( in_data.distribution_strategy in STRATEGIES['slicing']:
np.shape( local_in_data = in_data.get_local_data(copy=False)
in_data.get_local_data())) != 0) local_has_data = (np.prod(local_in_data.shape) != 0)
local_has_data_list = np.array( local_has_data_list = np.array(
self.comm.allgather(local_has_data)) self.comm.allgather(local_has_data))
nodes_with_data = np.where(local_has_data_list)[0] nodes_with_data = np.where(local_has_data_list)[0]
if np.shape(nodes_with_data)[0] > 1: if np.shape(nodes_with_data)[0] > 1:
raise ValueError( raise ValueError(
"ERROR: scalar index on first dimension, but " + "ERROR: scalar index on first dimension, " +
" more than one node has data!") "but more than one node has data!")
elif np.shape(nodes_with_data)[0] == 1: elif np.shape(nodes_with_data)[0] == 1:
node_with_data = nodes_with_data[0] node_with_data = nodes_with_data[0]
else: else:
...@@ -1417,12 +1433,13 @@ class _slicing_distributor(distributor): ...@@ -1417,12 +1433,13 @@ class _slicing_distributor(distributor):
broadcasted_shape = (0,) * len(temp_local_shape) broadcasted_shape = (0,) * len(temp_local_shape)
else: else:
broadcasted_shape = self.comm.bcast( broadcasted_shape = self.comm.bcast(
temp_local_shape, temp_local_shape,
root=node_with_data) root=node_with_data)
if self.comm.rank != node_with_data: if self.comm.rank != node_with_data:
temp_local_shape = np.array(broadcasted_shape) temp_local_shape = np.array(broadcasted_shape)
temp_local_shape[0] = 0 temp_local_shape[0] = 0
temp_local_shape = tuple(temp_local_shape) temp_local_shape = tuple(temp_local_shape)
# Case 2: The in_data d2o is only onedimensional # Case 2: The in_data d2o is only onedimensional
else: else:
# The data contained in the d2o must be stored on one # The data contained in the d2o must be stored on one
...@@ -1439,6 +1456,7 @@ class _slicing_distributor(distributor): ...@@ -1439,6 +1456,7 @@ class _slicing_distributor(distributor):
new_data = distributed_data_object( new_data = distributed_data_object(
local_data=reshaped_data, local_data=reshaped_data,
distribution_strategy=in_data.distribution_strategy, distribution_strategy=in_data.distribution_strategy,
copy=False,
comm=self.comm) comm=self.comm)
return new_data return new_data
else: else:
...@@ -1499,6 +1517,7 @@ class _slicing_distributor(distributor): ...@@ -1499,6 +1517,7 @@ class _slicing_distributor(distributor):
new_data = distributed_data_object( new_data = distributed_data_object(
local_data=reshaped_data, local_data=reshaped_data,
distribution_strategy='freeform', distribution_strategy='freeform',
copy=False,
comm=self.comm) comm=self.comm)
return new_data return new_data
else: else:
...@@ -1714,18 +1733,32 @@ class _not_distributor(distributor): ...@@ -1714,18 +1733,32 @@ class _not_distributor(distributor):
if np.isscalar(data_update): if np.isscalar(data_update):
update = data_update update = data_update
elif isinstance(data_update, distributed_data_object): elif isinstance(data_update, distributed_data_object):
update = data_update[from_key].get_full_data().\ update = data_update[from_key].get_full_data()
astype(self.dtype) update = update.astype(self.dtype,
copy=False)
else: else:
if isinstance(from_key, distributed_data_object): if isinstance(from_key, distributed_data_object):
from_key = from_key.get_full_data() from_key = from_key.get_full_data()
elif isinstance(from_key, list):
try:
from_key = [item.get_full_data() for item in from_key]
except(AttributeError):
pass
update = np.array(data_update, update = np.array(data_update,
copy=copy)[from_key].astype(self.dtype) copy=copy)[from_key]
update = update.astype(self.dtype,
copy=False)
data[to_key] = update data[to_key] = update
def collect_data(self, data, key, local_keys=False, **kwargs): def collect_data(self, data, key, local_keys=False, **kwargs):
if isinstance(key, distributed_data_object): if isinstance(key, distributed_data_object):
key = key.get_full_data() key = key.get_full_data()
elif isinstance(key, list):
try:
key = [item.get_full_data() for item in key]
except(AttributeError):
pass
new_data = data[key] new_data = data[key]
if isinstance(new_data, np.ndarray): if isinstance(new_data, np.ndarray):
if local_keys: if local_keys:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment