Commit fd33af41 authored by theos's avatar theos
Browse files

Improved set_full_data.

parent 2621861d
...@@ -300,6 +300,41 @@ def _infer_key_type(key): ...@@ -300,6 +300,41 @@ def _infer_key_type(key):
class distributor(object): class distributor(object):
def distribute_data(self, data=None, alias=None,
path=None, copy=True, **kwargs):
'''
distribute data checks
- whether the data is located on all nodes or only on node 0
- that the shape of 'data' matches the global_shape
'''
if 'h5py' in gdi and alias is not None:
data = self.load_data(alias=alias, path=path)
if data is None:
return np.empty(self.local_shape, dtype=self.dtype)
elif np.isscalar(data):
return np.ones(self.local_shape, dtype=self.dtype)*data
elif isinstance(data, np.ndarray) or \
isinstance(data, distributed_data_object):
data = self.extract_local_data(data)
if data.shape is not self.local_shape:
copy = True
if copy:
result_data = np.empty(self.local_shape, dtype=self.dtype)
result_data[:] = data
else:
result_data = data
return result_data
else:
new_data = np.array(data)
return new_data.astype(self.dtype,
copy=copy).reshape(self.local_shape)
def disperse_data(self, data, to_key, data_update, from_key=None, def disperse_data(self, data, to_key, data_update, from_key=None,
local_keys=False, copy=True, **kwargs): local_keys=False, copy=True, **kwargs):
# Check which keys we got: # Check which keys we got:
...@@ -456,18 +491,24 @@ class distributor(object): ...@@ -456,18 +491,24 @@ class distributor(object):
# bincounts # bincounts
for slice_list in slicing_generator(flat_shape, for slice_list in slicing_generator(flat_shape,
axes=(len(flat_shape)-1, )): axes=(len(flat_shape)-1, )):
local_counts[slice_list] = np.bincount(data[slice_list], if local_weights is not None:
weights=local_weights, current_weights = local_weights[slice_list]
minlength=length) else:
current_weights = None
local_counts[slice_list] = np.bincount(
data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering # restore the original ordering
# place the bincount stuff at the location of the first `axis` entry # place the bincount stuff at the location of the first `axis` entry
if axis is not None: if axis is not None:
# axis has been sorted above # axis has been sorted above
insert_position = axis[0] insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (range(0, insert_position) + return_order = (range(0, insert_position) +
[ndim-1, ] + [new_ndim-1, ] +
range(insert_position, ndim-1)) range(insert_position, new_ndim-1))
local_counts = np.ascontiguousarray( local_counts = np.ascontiguousarray(
local_counts.transpose(return_order)) local_counts.transpose(return_order))
return self._combine_local_bincount_counts(obj, local_counts, axis) return self._combine_local_bincount_counts(obj, local_counts, axis)
...@@ -714,43 +755,69 @@ class _slicing_distributor(distributor): ...@@ -714,43 +755,69 @@ class _slicing_distributor(distributor):
return result return result
def distribute_data(self, data=None, alias=None, # def distribute_data(self, data=None, alias=None,
path=None, copy=True, **kwargs): # path=None, copy=True, **kwargs):
''' # '''
distribute data checks # distribute data checks
- whether the data is located on all nodes or only on node 0 # - whether the data is located on all nodes or only on node 0
- that the shape of 'data' matches the global_shape # - that the shape of 'data' matches the global_shape
''' # '''
#
comm = self.comm ## comm = self.comm
#
if 'h5py' in gdi and alias is not None: # if 'h5py' in gdi and alias is not None:
data = self.load_data(alias=alias, path=path) # data = self.load_data(alias=alias, path=path)
#
local_data_available_Q = (data is not None) # if data is None:
data_available_Q = np.array(comm.allgather(local_data_available_Q)) # return np.empty(self.global_shape, dtype=self.dtype)
# elif np.isscalar(data):
if np.all(data_available_Q == False): # return np.ones(self.global_shape, dtype=self.dtype)*data
return np.empty(self.local_shape, dtype=self.dtype, order='C') # copy = False
# if all nodes got data, we assume that it is the right data and # elif isinstance(data, np.ndarray) or \
# store it individually. # isinstance(data, distributed_data_object):
elif np.all(data_available_Q == True): # data = self.extract_local_data(data)
if isinstance(data, distributed_data_object): #
temp_d2o = data.get_data((slice(self.local_start, # if data.shape is not self.local_shape:
self.local_end),), # copy = True
local_keys=True, #
copy=copy) # if copy:
return temp_d2o.get_local_data(copy=False).astype(self.dtype, # result_data = np.empty(self.local_shape, dtype=self.dtype)
copy=False) # result_data[:] = data
elif np.isscalar(data): # else:
return np.ones(self.local_shape, dtype=self.dtype)*data # result_data = data
else: #
return data[self.local_start:self.local_end].astype( # return result_data
self.dtype, #
copy=copy) # else:
else: # new_data = np.array(data)
raise ValueError( # return new_data.astype(self.dtype,
"ERROR: distribute_data must get data on all nodes!") # copy=copy).reshape(self.global_shape)
#
#
## local_data_available_Q = (data is not None)
## data_available_Q = np.array(comm.allgather(local_data_available_Q))
##
## if np.all(data_available_Q == False):
## return np.empty(self.local_shape, dtype=self.dtype, order='C')
## # if all nodes got data, we assume that it is the right data and
## # store it individually.
## elif np.all(data_available_Q == True):
## if isinstance(data, distributed_data_object):
## temp_d2o = data.get_data((slice(self.local_start,
## self.local_end),),
## local_keys=True,
## copy=copy)
## return temp_d2o.get_local_data(copy=False).astype(self.dtype,
## copy=False)
## elif np.isscalar(data):
## return np.ones(self.local_shape, dtype=self.dtype)*data
## else:
## return data[self.local_start:self.local_end].astype(
## self.dtype,
## copy=copy)
## else:
## raise ValueError(
## "ERROR: distribute_data must get data on all nodes!")
def _disperse_data_primitive(self, data, to_key, data_update, from_key, def _disperse_data_primitive(self, data, to_key, data_update, from_key,
copy, to_found, to_found_boolean, from_found, copy, to_found, to_found_boolean, from_found,
...@@ -1403,37 +1470,38 @@ class _slicing_distributor(distributor): ...@@ -1403,37 +1470,38 @@ class _slicing_distributor(distributor):
# if shape-casting was successfull, extract the data # if shape-casting was successfull, extract the data
else: else:
# If the first dimension matches only via broadcasting... if isinstance(data_object, distributed_data_object):
# Case 1: ...do broadcasting. This procedure does not depend on the # If the first dimension matches only via broadcasting...
# array type (ndarray or d2o) # Case 1: ...do broadcasting.
if matching_dimensions[0] == False: if matching_dimensions[0] == False:
extracted_data = data_object[0:1] extracted_data = data_object.get_full_data()
extracted_data = extracted_data[0]
# Case 2: First dimension fits directly and data_object is a d2o else:
elif isinstance(data_object, distributed_data_object): # Case 2: First dimension fits directly and data_object is
# Check if both d2os have the same slicing # a d2o
# If the distributor is exactly the same, extract the data # Check if both d2os have the same slicing
if self is data_object.distributor: # If the distributor is exactly the same, extract the data
# Simply take the local data if self is data_object.distributor:
extracted_data = data_object.data # Simply take the local data
# If the distributor is not exactly the same, check if the
# geometry matches if it is a slicing distributor
# -> comm and local shapes
elif (isinstance(data_object.distributor,
_slicing_distributor) and
(self.comm is data_object.distributor.comm) and
(np.all(self.all_local_slices ==
data_object.distributor.all_local_slices))):
extracted_data = data_object.data extracted_data = data_object.data
# If the distributor is not exactly the same, check if the
# geometry matches if it is a slicing distributor
# -> comm and local shapes
elif (isinstance(data_object.distributor,
_slicing_distributor) and
(self.comm is data_object.distributor.comm) and
(np.all(self.all_local_slices ==
data_object.distributor.all_local_slices))):
extracted_data = data_object.data
else: else:
# Case 2: no. All nodes extract their local slice from the # Case 2: no. All nodes extract their local slice from the
# data_object # data_object
extracted_data =\ extracted_data =\
data_object.get_data(slice(self.local_start, data_object.get_data(slice(self.local_start,
self.local_end), self.local_end),
local_keys=True) local_keys=True)
extracted_data = extracted_data.get_local_data() extracted_data = extracted_data.get_local_data()
# # Check if the distributor and the comm match # # Check if the distributor and the comm match
...@@ -1454,7 +1522,11 @@ class _slicing_distributor(distributor): ...@@ -1454,7 +1522,11 @@ class _slicing_distributor(distributor):
# extracted_data = extracted_data.get_local_data() # extracted_data = extracted_data.get_local_data()
# #
# #
# Case 2: np-array
# If the first dimension matches only via broadcasting
# ...do broadcasting.
elif matching_dimensions[0] == False:
extracted_data = data_object[0:1]
# Case 3: First dimension fits directly and data_object is an # Case 3: First dimension fits directly and data_object is an
# generic array # generic array
else: else:
...@@ -1464,6 +1536,7 @@ class _slicing_distributor(distributor): ...@@ -1464,6 +1536,7 @@ class _slicing_distributor(distributor):
return extracted_data return extracted_data
def _reshape_foreign_data(self, foreign): def _reshape_foreign_data(self, foreign):
# Case 1: # Case 1:
# check if the shapes match directly # check if the shapes match directly
if self.global_shape == foreign.shape: if self.global_shape == foreign.shape:
...@@ -2043,24 +2116,42 @@ class _not_distributor(distributor): ...@@ -2043,24 +2116,42 @@ class _not_distributor(distributor):
return result_object return result_object
def distribute_data(self, data, alias=None, path=None, copy=True, # def distribute_data(self, data, alias=None, path=None, copy=True,
**kwargs): # **kwargs):
if 'h5py' in gdi and alias is not None: # if 'h5py' in gdi and alias is not None:
data = self.load_data(alias=alias, path=path) # data = self.load_data(alias=alias, path=path)
#
if data is None: # if data is None:
return np.empty(self.global_shape, dtype=self.dtype) # return np.empty(self.global_shape, dtype=self.dtype)
elif isinstance(data, distributed_data_object): # elif np.isscalar(data):
new_data = data.get_full_data() # return np.ones(self.global_shape, dtype=self.dtype)*data
elif isinstance(data, np.ndarray): # copy = False
new_data = data # elif isinstance(data, np.ndarray) or \
elif np.isscalar(data): # isinstance(data, distributed_data_object):
new_data = np.ones(self.global_shape, dtype=self.dtype)*data # data = self.extract_local_data(data)
copy = False # result_data = np.empty(self.local_shape, dtype=self.dtype)
else: # result_data[:] = data
new_data = np.array(data) # return result_data
return new_data.astype(self.dtype, #
copy=copy).reshape(self.global_shape) # else:
# new_data = np.array(data)
# return new_data.astype(self.dtype,
# copy=copy).reshape(self.global_shape)
#
#
## if data is None:
## return np.empty(self.global_shape, dtype=self.dtype)
## elif isinstance(data, distributed_data_object):
## new_data = data.get_full_data()
## elif isinstance(data, np.ndarray):
## new_data = data
## elif np.isscalar(data):
## new_data = np.ones(self.global_shape, dtype=self.dtype)*data
## copy = False
## else:
## new_data = np.array(data)
## return new_data.astype(self.dtype,
## copy=copy).reshape(self.global_shape)
def _disperse_data_primitive(self, data, to_key, data_update, from_key, def _disperse_data_primitive(self, data, to_key, data_update, from_key,
copy, to_found, to_found_boolean, from_found, copy, to_found, to_found_boolean, from_found,
...@@ -2118,9 +2209,15 @@ class _not_distributor(distributor): ...@@ -2118,9 +2209,15 @@ class _not_distributor(distributor):
def extract_local_data(self, data_object): def extract_local_data(self, data_object):
if isinstance(data_object, distributed_data_object): if isinstance(data_object, distributed_data_object):
return data_object.get_full_data().reshape(self.global_shape) result_data = data_object.get_full_data()
else: else:
return np.array(data_object)[:].reshape(self.global_shape) result_data = np.array(data_object)[:]
try:
result_data = result_data.reshape(self.global_shape)
except ValueError:
pass
return result_data
def flatten(self, data, inplace=False): def flatten(self, data, inplace=False):
if inplace: if inplace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment