Commit 86674115 authored by theos's avatar theos
Browse files

Improved distributed_data_object's performance by avoiding unnecessary copies.

Added `copy` kwargs where missing.
parent 79d0153f
Pipeline #1275 skipped
......@@ -390,9 +390,10 @@ class distributor(object):
else:
temp_data = data_update
temp_data_update = distributed_data_object(
local_data=temp_data,
distribution_strategy='freeform',
comm=self.comm)
local_data=temp_data,
distribution_strategy='freeform',
copy=False,
comm=self.comm)
# disperse the data one after another
self._disperse_data_primitive(
data=data,
......@@ -409,7 +410,6 @@ class distributor(object):
class _slicing_distributor(distributor):
def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
self.comm = comm
......@@ -542,9 +542,10 @@ class _slicing_distributor(distributor):
if isinstance(data, distributed_data_object):
temp_d2o = data.get_data((slice(self.local_start,
self.local_end),),
local_keys=True)
return temp_d2o.get_local_data().astype(self.dtype,
copy=copy)
local_keys=True,
copy=copy)
return temp_d2o.get_local_data(copy=False).astype(self.dtype,
copy=False)
else:
return data[self.local_start:self.local_end].astype(
self.dtype,
......@@ -673,15 +674,18 @@ class _slicing_distributor(distributor):
l = data_length
if isinstance(data_update, distributed_data_object):
data[local_to_key] = data_update.get_data(
local_data_update = data_update.get_data(
slice(o[rank], o[rank] + l),
local_keys=True
).get_local_data().astype(self.dtype)
).get_local_data(copy=False)
data[local_to_key] = local_data_update.astype(self.dtype,
copy=False)
elif np.isscalar(data_update):
data[local_to_key] = data_update
else:
data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l],
copy=copy).astype(self.dtype)
copy=copy).astype(self.dtype,
copy=False)
return data
def disperse_data_to_slices(self, data, to_slices,
......@@ -763,10 +767,12 @@ class _slicing_distributor(distributor):
update_slice = localized_from_slice + from_slices[1:]
if isinstance(data_update, distributed_data_object):
local_data_update = data_update.get_data(
selected_update = data_update.get_data(
key=update_slice,
local_keys=True
).get_local_data(copy=copy).astype(self.dtype)
local_keys=True)
local_data_update = selected_update.get_local_data(copy=False)
local_data_update = local_data_update.astype(self.dtype,
copy=False)
if np.prod(np.shape(local_data_update)) != 0:
data[local_to_slice] = local_data_update
# elif np.isscalar(data_update):
......@@ -776,9 +782,10 @@ class _slicing_distributor(distributor):
if np.prod(np.shape(local_data_update)) != 0:
data[local_to_slice] = np.array(
local_data_update,
copy=copy).astype(self.dtype)
copy=copy).astype(self.dtype,
copy=False)
def collect_data(self, data, key, local_keys=False, **kwargs):
def collect_data(self, data, key, local_keys=False, copy=True, **kwargs):
# collect_data supports three types of keys
# Case 1: key is a slicing/index tuple
# Case 2: key is a boolean-array of the same shape as self
......@@ -793,7 +800,8 @@ class _slicing_distributor(distributor):
comm = self.comm
if local_keys is False:
return self._collect_data_primitive(data, key, found,
found_boolean, **kwargs)
found_boolean, copy=copy,
**kwargs)
else:
# assert that all keys are from same type
found_list = comm.allgather(found)
......@@ -817,7 +825,7 @@ class _slicing_distributor(distributor):
# build the locally fed d2o
temp_d2o = self._collect_data_primitive(data, temp_key, found,
found_boolean,
**kwargs)
copy=copy, **kwargs)
# collect the data stored in the d2o to the individual target
# rank
temp_data = temp_d2o.get_full_data(target_rank=i)
......@@ -825,19 +833,21 @@ class _slicing_distributor(distributor):
individual_data = temp_data
i += 1
return_d2o = distributed_data_object(
local_data=individual_data,
distribution_strategy='freeform',
comm=self.comm)
local_data=individual_data,
distribution_strategy='freeform',
copy=False,
comm=self.comm)
return return_d2o
def _collect_data_primitive(self, data, key, found, found_boolean,
**kwargs):
copy=True, **kwargs):
# Case 1: key is a slice-tuple. Hence, the basic indexing/slicing
# machinery will be used
if found == 'slicetuple':
return self.collect_data_from_slices(data=data,
slice_objects=key,
copy=copy,
**kwargs)
# Case 2: key is an array
elif (found == 'ndarray' or found == 'd2o'):
......@@ -845,6 +855,7 @@ class _slicing_distributor(distributor):
if found_boolean:
return self.collect_data_from_bool(data=data,
boolean_key=key,
copy=copy,
**kwargs)
# Case 2.2: The array is not boolean. Only 1-dimensional
# advanced slicing is supported.
......@@ -854,16 +865,18 @@ class _slicing_distributor(distributor):
"WARNING: Only one-dimensional advanced indexing " +
"is supported"))
# Make a recursive call in order to trigger the 'list'-section
return self.collect_data(data=data, key=[key], **kwargs)
return self.collect_data(data=data, key=[key], copy=copy,
**kwargs)
# Case 3 : key is a list. This list is interpreted as one-dimensional
# advanced indexing list.
elif found == 'indexinglist':
return self.collect_data_from_list(data=data,
list_key=key,
copy=copy,
**kwargs)
def collect_data_from_list(self, data, list_key, **kwargs):
def collect_data_from_list(self, data, list_key, copy=True, **kwargs):
if list_key == []:
raise ValueError(about._errors.cstring(
"ERROR: key == [] is an unsupported key!"))
......@@ -872,6 +885,7 @@ class _slicing_distributor(distributor):
global_result = distributed_data_object(
local_data=local_result,
distribution_strategy='freeform',
copy=copy,
comm=self.comm)
return global_result
......@@ -1000,12 +1014,13 @@ class _slicing_distributor(distributor):
return result
def collect_data_from_bool(self, data, boolean_key, **kwargs):
def collect_data_from_bool(self, data, boolean_key, copy=True, **kwargs):
local_boolean_key = self.extract_local_data(boolean_key)
local_result = data[local_boolean_key]
global_result = distributed_data_object(
local_data=local_result,
distribution_strategy='freeform',
copy=copy,
comm=self.comm)
return global_result
......@@ -1023,7 +1038,7 @@ class _slicing_distributor(distributor):
comm.barrier()
return new_data
def collect_data_from_slices(self, data, slice_objects,
def collect_data_from_slices(self, data, slice_objects, copy=True,
target_rank='all'):
(slice_objects, sliceified) = self._sliceify(slice_objects)
......@@ -1046,6 +1061,7 @@ class _slicing_distributor(distributor):
global_result = distributed_data_object(
local_data=local_result,
distribution_strategy='freeform',
copy=copy,
comm=self.comm)
return self._defold(global_result, sliceified)
......@@ -1357,7 +1373,7 @@ class _slicing_distributor(distributor):
# low level mess!!
if isinstance(in_data, distributed_data_object):
local_data = in_data.data
local_data = in_data.get_local_data(copy=False)
elif isinstance(in_data, np.ndarray) == False:
local_data = np.array(in_data, copy=False)
in_data = local_data
......@@ -1397,17 +1413,17 @@ class _slicing_distributor(distributor):
# and broadcast the shape to the others
if sliceified[0]:
# Case 1: The in_data d2o has more than one dimension
if len(in_data.shape) > 1:
local_has_data = (np.prod(
np.shape(
in_data.get_local_data())) != 0)
if len(in_data.shape) > 1 and \
in_data.distribution_strategy in STRATEGIES['slicing']:
local_in_data = in_data.get_local_data(copy=False)
local_has_data = (np.prod(local_in_data.shape) != 0)
local_has_data_list = np.array(
self.comm.allgather(local_has_data))
self.comm.allgather(local_has_data))
nodes_with_data = np.where(local_has_data_list)[0]
if np.shape(nodes_with_data)[0] > 1:
raise ValueError(
"ERROR: scalar index on first dimension, but " +
" more than one node has data!")
"ERROR: scalar index on first dimension, " +
"but more than one node has data!")
elif np.shape(nodes_with_data)[0] == 1:
node_with_data = nodes_with_data[0]
else:
......@@ -1417,12 +1433,13 @@ class _slicing_distributor(distributor):
broadcasted_shape = (0,) * len(temp_local_shape)
else:
broadcasted_shape = self.comm.bcast(
temp_local_shape,
root=node_with_data)
temp_local_shape,
root=node_with_data)
if self.comm.rank != node_with_data:
temp_local_shape = np.array(broadcasted_shape)
temp_local_shape[0] = 0
temp_local_shape = tuple(temp_local_shape)
# Case 2: The in_data d2o is only onedimensional
else:
# The data contained in the d2o must be stored on one
......@@ -1439,6 +1456,7 @@ class _slicing_distributor(distributor):
new_data = distributed_data_object(
local_data=reshaped_data,
distribution_strategy=in_data.distribution_strategy,
copy=False,
comm=self.comm)
return new_data
else:
......@@ -1499,6 +1517,7 @@ class _slicing_distributor(distributor):
new_data = distributed_data_object(
local_data=reshaped_data,
distribution_strategy='freeform',
copy=False,
comm=self.comm)
return new_data
else:
......@@ -1714,18 +1733,32 @@ class _not_distributor(distributor):
if np.isscalar(data_update):
update = data_update
elif isinstance(data_update, distributed_data_object):
update = data_update[from_key].get_full_data().\
astype(self.dtype)
update = data_update[from_key].get_full_data()
update = update.astype(self.dtype,
copy=False)
else:
if isinstance(from_key, distributed_data_object):
from_key = from_key.get_full_data()
elif isinstance(from_key, list):
try:
from_key = [item.get_full_data() for item in from_key]
except(AttributeError):
pass
update = np.array(data_update,
copy=copy)[from_key].astype(self.dtype)
copy=copy)[from_key]
update = update.astype(self.dtype,
copy=False)
data[to_key] = update
def collect_data(self, data, key, local_keys=False, **kwargs):
if isinstance(key, distributed_data_object):
key = key.get_full_data()
elif isinstance(key, list):
try:
key = [item.get_full_data() for item in key]
except(AttributeError):
pass
new_data = data[key]
if isinstance(new_data, np.ndarray):
if local_keys:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment