Commit fd33af41 authored by theos's avatar theos
Browse files

Improved set_full_data.

parent 2621861d
......@@ -300,6 +300,41 @@ def _infer_key_type(key):
class distributor(object):
def distribute_data(self, data=None, alias=None,
                    path=None, copy=True, **kwargs):
    """Turn `data` into a local numpy array of shape `self.local_shape`.

    distribute data checks
    - whether the data is located on all nodes or only on node 0
    - that the shape of `data` matches the global_shape

    Parameters
    ----------
    data : None, scalar, np.ndarray, distributed_data_object or array-like
        Input data. None yields an uninitialized array; a scalar is
        broadcast over the local shape.
    alias : str, optional
        hdf5 dataset name; if h5py is available it takes precedence
        over `data`.
    path : str, optional
        Path of the hdf5 file referenced by `alias`.
    copy : bool
        If True the returned array is guaranteed not to share memory
        with `data`.

    Returns
    -------
    np.ndarray
        Array of shape `self.local_shape` and dtype `self.dtype`.
    """
    if 'h5py' in gdi and alias is not None:
        data = self.load_data(alias=alias, path=path)

    if data is None:
        return np.empty(self.local_shape, dtype=self.dtype)
    elif np.isscalar(data):
        return np.ones(self.local_shape, dtype=self.dtype) * data
    elif isinstance(data, np.ndarray) or \
            isinstance(data, distributed_data_object):
        data = self.extract_local_data(data)
        # Bug fix: tuples must be compared by value, not identity.
        # The original `is not` was (almost) always True and therefore
        # forced a copy on virtually every call.
        if data.shape != self.local_shape:
            copy = True

        if copy:
            result_data = np.empty(self.local_shape, dtype=self.dtype)
            result_data[:] = data
        else:
            result_data = data

        return result_data
    else:
        new_data = np.array(data)
        return new_data.astype(self.dtype,
                               copy=copy).reshape(self.local_shape)
def disperse_data(self, data, to_key, data_update, from_key=None,
local_keys=False, copy=True, **kwargs):
# Check which keys we got:
......@@ -456,8 +491,13 @@ class distributor(object):
# bincounts
for slice_list in slicing_generator(flat_shape,
axes=(len(flat_shape)-1, )):
local_counts[slice_list] = np.bincount(data[slice_list],
weights=local_weights,
if local_weights is not None:
current_weights = local_weights[slice_list]
else:
current_weights = None
local_counts[slice_list] = np.bincount(
data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
......@@ -465,9 +505,10 @@ class distributor(object):
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (range(0, insert_position) +
[ndim-1, ] +
range(insert_position, ndim-1))
[new_ndim-1, ] +
range(insert_position, new_ndim-1))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return self._combine_local_bincount_counts(obj, local_counts, axis)
......@@ -714,43 +755,69 @@ class _slicing_distributor(distributor):
return result
def distribute_data(self, data=None, alias=None,
                    path=None, copy=True, **kwargs):
    """Distribute `data` among the MPI nodes of this slicing distributor.

    distribute data checks
    - whether the data is located on all nodes or only on node 0
    - that the shape of `data` matches the global_shape

    Parameters
    ----------
    data : None, scalar, np.ndarray or distributed_data_object
        Must be supplied on either all nodes or none of them.
    alias : str, optional
        hdf5 dataset name; if h5py is available it takes precedence
        over `data`.
    path : str, optional
        Path of the hdf5 file referenced by `alias`.
    copy : bool
        Whether the extracted slice must be copied.

    Returns
    -------
    np.ndarray
        The node-local slice, shaped `self.local_shape`.

    Raises
    ------
    ValueError
        If only a subset of the nodes received data.
    """
    comm = self.comm

    if 'h5py' in gdi and alias is not None:
        data = self.load_data(alias=alias, path=path)

    local_data_available_Q = (data is not None)
    data_available_Q = np.array(comm.allgather(local_data_available_Q))

    # Idiom fix: use .any()/.all() instead of comparing a boolean
    # array against the literals True/False.
    if not data_available_Q.any():
        # no node has data -> return an uninitialized local array
        return np.empty(self.local_shape, dtype=self.dtype, order='C')
    # if all nodes got data, we assume that it is the right data and
    # store it individually.
    elif data_available_Q.all():
        if isinstance(data, distributed_data_object):
            temp_d2o = data.get_data((slice(self.local_start,
                                            self.local_end),),
                                     local_keys=True,
                                     copy=copy)
            return temp_d2o.get_local_data(copy=False).astype(self.dtype,
                                                              copy=False)
        elif np.isscalar(data):
            return np.ones(self.local_shape, dtype=self.dtype) * data
        else:
            return data[self.local_start:self.local_end].astype(
                self.dtype,
                copy=copy)
    else:
        raise ValueError(
            "ERROR: distribute_data must get data on all nodes!")
# def distribute_data(self, data=None, alias=None,
# path=None, copy=True, **kwargs):
# '''
# distribute data checks
# - whether the data is located on all nodes or only on node 0
# - that the shape of 'data' matches the global_shape
# '''
#
## comm = self.comm
#
# if 'h5py' in gdi and alias is not None:
# data = self.load_data(alias=alias, path=path)
#
# if data is None:
# return np.empty(self.global_shape, dtype=self.dtype)
# elif np.isscalar(data):
# return np.ones(self.global_shape, dtype=self.dtype)*data
# copy = False
# elif isinstance(data, np.ndarray) or \
# isinstance(data, distributed_data_object):
# data = self.extract_local_data(data)
#
# if data.shape is not self.local_shape:
# copy = True
#
# if copy:
# result_data = np.empty(self.local_shape, dtype=self.dtype)
# result_data[:] = data
# else:
# result_data = data
#
# return result_data
#
# else:
# new_data = np.array(data)
# return new_data.astype(self.dtype,
# copy=copy).reshape(self.global_shape)
#
#
## local_data_available_Q = (data is not None)
## data_available_Q = np.array(comm.allgather(local_data_available_Q))
##
## if np.all(data_available_Q == False):
## return np.empty(self.local_shape, dtype=self.dtype, order='C')
## # if all nodes got data, we assume that it is the right data and
## # store it individually.
## elif np.all(data_available_Q == True):
## if isinstance(data, distributed_data_object):
## temp_d2o = data.get_data((slice(self.local_start,
## self.local_end),),
## local_keys=True,
## copy=copy)
## return temp_d2o.get_local_data(copy=False).astype(self.dtype,
## copy=False)
## elif np.isscalar(data):
## return np.ones(self.local_shape, dtype=self.dtype)*data
## else:
## return data[self.local_start:self.local_end].astype(
## self.dtype,
## copy=copy)
## else:
## raise ValueError(
## "ERROR: distribute_data must get data on all nodes!")
def _disperse_data_primitive(self, data, to_key, data_update, from_key,
copy, to_found, to_found_boolean, from_found,
......@@ -1403,14 +1470,15 @@ class _slicing_distributor(distributor):
# if shape-casting was successfull, extract the data
else:
if isinstance(data_object, distributed_data_object):
# If the first dimension matches only via broadcasting...
# Case 1: ...do broadcasting. This procedure does not depend on the
# array type (ndarray or d2o)
# Case 1: ...do broadcasting.
if matching_dimensions[0] == False:
extracted_data = data_object[0:1]
# Case 2: First dimension fits directly and data_object is a d2o
elif isinstance(data_object, distributed_data_object):
extracted_data = data_object.get_full_data()
extracted_data = extracted_data[0]
else:
# Case 2: First dimension fits directly and data_object is
# a d2o
# Check if both d2os have the same slicing
# If the distributor is exactly the same, extract the data
if self is data_object.distributor:
......@@ -1454,7 +1522,11 @@ class _slicing_distributor(distributor):
# extracted_data = extracted_data.get_local_data()
#
#
# Case 2: np-array
# If the first dimension matches only via broadcasting
# ...do broadcasting.
elif matching_dimensions[0] == False:
extracted_data = data_object[0:1]
# Case 3: First dimension fits directly and data_object is an
# generic array
else:
......@@ -1464,6 +1536,7 @@ class _slicing_distributor(distributor):
return extracted_data
def _reshape_foreign_data(self, foreign):
# Case 1:
# check if the shapes match directly
if self.global_shape == foreign.shape:
......@@ -2043,24 +2116,42 @@ class _not_distributor(distributor):
return result_object
def distribute_data(self, data, alias=None, path=None, copy=True,
                    **kwargs):
    """Return the full data array for this non-distributing distributor.

    An hdf5 `alias` (when h5py is available) takes precedence over
    `data`. None produces an uninitialized array; a scalar is broadcast
    over the global shape; arrays and d2os are taken as-is.
    """
    if 'h5py' in gdi and alias is not None:
        data = self.load_data(alias=alias, path=path)

    # Guard clause: nothing to distribute -> uninitialized array.
    if data is None:
        return np.empty(self.global_shape, dtype=self.dtype)

    if isinstance(data, distributed_data_object):
        gathered = data.get_full_data()
    elif isinstance(data, np.ndarray):
        gathered = data
    elif np.isscalar(data):
        # The broadcast array is freshly built, so no extra copy is
        # needed afterwards.
        gathered = np.ones(self.global_shape, dtype=self.dtype) * data
        copy = False
    else:
        gathered = np.array(data)

    typed = gathered.astype(self.dtype, copy=copy)
    return typed.reshape(self.global_shape)
# def distribute_data(self, data, alias=None, path=None, copy=True,
# **kwargs):
# if 'h5py' in gdi and alias is not None:
# data = self.load_data(alias=alias, path=path)
#
# if data is None:
# return np.empty(self.global_shape, dtype=self.dtype)
# elif np.isscalar(data):
# return np.ones(self.global_shape, dtype=self.dtype)*data
# copy = False
# elif isinstance(data, np.ndarray) or \
# isinstance(data, distributed_data_object):
# data = self.extract_local_data(data)
# result_data = np.empty(self.local_shape, dtype=self.dtype)
# result_data[:] = data
# return result_data
#
# else:
# new_data = np.array(data)
# return new_data.astype(self.dtype,
# copy=copy).reshape(self.global_shape)
#
#
## if data is None:
## return np.empty(self.global_shape, dtype=self.dtype)
## elif isinstance(data, distributed_data_object):
## new_data = data.get_full_data()
## elif isinstance(data, np.ndarray):
## new_data = data
## elif np.isscalar(data):
## new_data = np.ones(self.global_shape, dtype=self.dtype)*data
## copy = False
## else:
## new_data = np.array(data)
## return new_data.astype(self.dtype,
## copy=copy).reshape(self.global_shape)
def _disperse_data_primitive(self, data, to_key, data_update, from_key,
copy, to_found, to_found_boolean, from_found,
......@@ -2118,9 +2209,15 @@ class _not_distributor(distributor):
def extract_local_data(self, data_object):
    """Return the full content of `data_object` as a numpy array.

    The result is reshaped to `self.global_shape` when the sizes allow
    it; otherwise the unreshaped array is returned (deliberate
    best-effort, preserved from the original).
    """
    if isinstance(data_object, distributed_data_object):
        result_data = data_object.get_full_data()
    else:
        # Fix: `np.array(...)[:]` added a redundant view on top of the
        # fresh array and raised IndexError for 0-d (scalar) input;
        # `np.array` alone already yields an independent copy.
        result_data = np.array(data_object)
    try:
        result_data = result_data.reshape(self.global_shape)
    except ValueError:
        # size mismatch -> keep the data unreshaped
        pass
    return result_data
def flatten(self, data, inplace=False):
if inplace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment