Commit 41c20c69 authored by Theo Steininger's avatar Theo Steininger

distributed_data_object can now process hdf5 dataset directly.

parent 60a48b67
Pipeline #10631 passed with stage
in 17 minutes and 59 seconds
...@@ -318,12 +318,17 @@ class distributor(object): ...@@ -318,12 +318,17 @@ class distributor(object):
if 'h5py' in gdi and alias is not None: if 'h5py' in gdi and alias is not None:
data = self.load_data(alias=alias, path=path) data = self.load_data(alias=alias, path=path)
h5py_dataset_Q = False
if 'h5py' in gdi:
h5py_dataset_Q = isinstance(data, h5py.Dataset)
if data is None: if data is None:
return np.empty(self.local_shape, dtype=self.dtype) return np.empty(self.local_shape, dtype=self.dtype)
elif np.isscalar(data): elif np.isscalar(data):
return np.ones(self.local_shape, dtype=self.dtype)*data return np.ones(self.local_shape, dtype=self.dtype)*data
elif isinstance(data, np.ndarray) or \ elif isinstance(data, np.ndarray) or \
isinstance(data, distributed_data_object): isinstance(data, distributed_data_object) or \
h5py_dataset_Q:
data = self.extract_local_data(data) data = self.extract_local_data(data)
if data.shape is not self.local_shape: if data.shape is not self.local_shape:
...@@ -1455,9 +1460,13 @@ class _slicing_distributor(distributor): ...@@ -1455,9 +1460,13 @@ class _slicing_distributor(distributor):
return (local_start, local_stop) return (local_start, local_stop)
def extract_local_data(self, data_object): def extract_local_data(self, data_object):
h5py_dataset_Q = False
if 'h5py' in gdi:
h5py_dataset_Q = isinstance(data_object, h5py.Dataset)
# if data_object is not a ndarray or a d2o, cast it to a ndarray # if data_object is not a ndarray or a d2o, cast it to a ndarray
if not (isinstance(data_object, np.ndarray) or if not (isinstance(data_object, np.ndarray) or
isinstance(data_object, distributed_data_object)): isinstance(data_object, distributed_data_object) or
h5py_dataset_Q):
data_object = np.array(data_object) data_object = np.array(data_object)
# check if the shapes are remotely compatible, reshape if possible # check if the shapes are remotely compatible, reshape if possible
# and determine which dimensions match only via broadcasting # and determine which dimensions match only via broadcasting
...@@ -1483,7 +1492,7 @@ class _slicing_distributor(distributor): ...@@ -1483,7 +1492,7 @@ class _slicing_distributor(distributor):
if isinstance(data_object, distributed_data_object): if isinstance(data_object, distributed_data_object):
# If the first dimension matches only via broadcasting... # If the first dimension matches only via broadcasting...
# Case 1: ...do broadcasting. # Case 1: ...do broadcasting.
if matching_dimensions[0] == False: if not matching_dimensions[0]:
extracted_data = data_object.get_full_data() extracted_data = data_object.get_full_data()
extracted_data = extracted_data[0] extracted_data = extracted_data[0]
else: else:
...@@ -1505,7 +1514,8 @@ class _slicing_distributor(distributor): ...@@ -1505,7 +1514,8 @@ class _slicing_distributor(distributor):
extracted_data = data_object.data extracted_data = data_object.data
else: else:
# Case 2: no. All nodes extract their local slice from the # Case 2: no.
# All nodes extract their local slice from the
# data_object # data_object
extracted_data =\ extracted_data =\
data_object.get_data(slice(self.local_start, data_object.get_data(slice(self.local_start,
...@@ -1535,10 +1545,10 @@ class _slicing_distributor(distributor): ...@@ -1535,10 +1545,10 @@ class _slicing_distributor(distributor):
# Case 2: np-array # Case 2: np-array
# If the first dimension matches only via broadcasting # If the first dimension matches only via broadcasting
# ...do broadcasting. # ...do broadcasting.
elif matching_dimensions[0] == False: elif not matching_dimensions[0]:
extracted_data = data_object[0:1] extracted_data = data_object[0:1]
# Case 3: First dimension fits directly and data_object is an # Case 3: First dimension fits directly and data_object is an
# generic array # generic array or h5py dataset
else: else:
extracted_data =\ extracted_data =\
data_object[self.local_start:self.local_end] data_object[self.local_start:self.local_end]
...@@ -1794,7 +1804,7 @@ class _slicing_distributor(distributor): ...@@ -1794,7 +1804,7 @@ class _slicing_distributor(distributor):
if isinstance(in_data, distributed_data_object): if isinstance(in_data, distributed_data_object):
local_data = in_data.get_local_data(copy=False) local_data = in_data.get_local_data(copy=False)
elif isinstance(in_data, np.ndarray) == False: elif not isinstance(in_data, np.ndarray):
local_data = np.array(in_data, copy=False) local_data = np.array(in_data, copy=False)
in_data = local_data in_data = local_data
else: else:
...@@ -1887,7 +1897,7 @@ class _slicing_distributor(distributor): ...@@ -1887,7 +1897,7 @@ class _slicing_distributor(distributor):
# low level mess!! # low level mess!!
if isinstance(in_data, distributed_data_object): if isinstance(in_data, distributed_data_object):
local_data = in_data.data local_data = in_data.data
elif isinstance(in_data, np.ndarray) == False: elif not isinstance(in_data, np.ndarray):
local_data = np.array(in_data, copy=False) local_data = np.array(in_data, copy=False)
in_data = local_data in_data = local_data
else: else:
...@@ -2096,7 +2106,7 @@ if 'pyfftw' in gdi: ...@@ -2096,7 +2106,7 @@ if 'pyfftw' in gdi:
# pyfftw.local_size crashes if any of the entries of global_shape # pyfftw.local_size crashes if any of the entries of global_shape
working_shape = np.array(global_shape) working_shape = np.array(global_shape)
mask = (working_shape == 0) mask = (working_shape == 0)
if mask[0] == True: if mask[0]:
start = 0 start = 0
end = 0 end = 0
return (start, end, global_shape) return (start, end, global_shape)
......
...@@ -20,4 +20,4 @@ ...@@ -20,4 +20,4 @@
# 1) we don't load dependencies by storing it in __init__.py # 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason # 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module # 3) we can import it into your module module
__version__ = '1.0.7' __version__ = '1.0.8'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment