diff --git a/d2o/distributor_factory.py b/d2o/distributor_factory.py index 61067e6138f9fa6783917ba667a595ea0c3cddc4..1cfeab31cdc877fddc6a156e0699ca2bf7b570e6 100644 --- a/d2o/distributor_factory.py +++ b/d2o/distributor_factory.py @@ -318,12 +318,17 @@ class distributor(object): if 'h5py' in gdi and alias is not None: data = self.load_data(alias=alias, path=path) + h5py_dataset_Q = False + if 'h5py' in gdi: + h5py_dataset_Q = isinstance(data, h5py.Dataset) + if data is None: return np.empty(self.local_shape, dtype=self.dtype) elif np.isscalar(data): return np.ones(self.local_shape, dtype=self.dtype)*data elif isinstance(data, np.ndarray) or \ - isinstance(data, distributed_data_object): + isinstance(data, distributed_data_object) or \ + h5py_dataset_Q: data = self.extract_local_data(data) if data.shape is not self.local_shape: @@ -1455,9 +1460,13 @@ class _slicing_distributor(distributor): return (local_start, local_stop) def extract_local_data(self, data_object): + h5py_dataset_Q = False + if 'h5py' in gdi: + h5py_dataset_Q = isinstance(data_object, h5py.Dataset) # if data_object is not a ndarray or a d2o, cast it to a ndarray if not (isinstance(data_object, np.ndarray) or - isinstance(data_object, distributed_data_object)): + isinstance(data_object, distributed_data_object) or + h5py_dataset_Q): data_object = np.array(data_object) # check if the shapes are remotely compatible, reshape if possible # and determine which dimensions match only via broadcasting @@ -1483,7 +1492,7 @@ class _slicing_distributor(distributor): if isinstance(data_object, distributed_data_object): # If the first dimension matches only via broadcasting... # Case 1: ...do broadcasting. - if matching_dimensions[0] == False: + if not matching_dimensions[0]: extracted_data = data_object.get_full_data() extracted_data = extracted_data[0] else: @@ -1505,7 +1514,8 @@ class _slicing_distributor(distributor): extracted_data = data_object.data else: - # Case 2: no. All nodes extract their local slice from the + # Case 2: no. + # All nodes extract their local slice from the # data_object extracted_data =\ data_object.get_data(slice(self.local_start, @@ -1535,10 +1545,10 @@ class _slicing_distributor(distributor): # Case 2: np-array # If the first dimension matches only via broadcasting # ...do broadcasting. - elif matching_dimensions[0] == False: + elif not matching_dimensions[0]: extracted_data = data_object[0:1] # Case 3: First dimension fits directly and data_object is an - # generic array + # generic array or h5py dataset else: extracted_data =\ data_object[self.local_start:self.local_end] @@ -1794,7 +1804,7 @@ class _slicing_distributor(distributor): if isinstance(in_data, distributed_data_object): local_data = in_data.get_local_data(copy=False) - elif isinstance(in_data, np.ndarray) == False: + elif not isinstance(in_data, np.ndarray): local_data = np.array(in_data, copy=False) in_data = local_data else: @@ -1887,7 +1897,7 @@ class _slicing_distributor(distributor): # low level mess!! if isinstance(in_data, distributed_data_object): local_data = in_data.data - elif isinstance(in_data, np.ndarray) == False: + elif not isinstance(in_data, np.ndarray): local_data = np.array(in_data, copy=False) in_data = local_data else: @@ -2096,7 +2106,7 @@ if 'pyfftw' in gdi: # pyfftw.local_size crashes if any of the entries of global_shape working_shape = np.array(global_shape) mask = (working_shape == 0) - if mask[0] == True: + if mask[0]: start = 0 end = 0 return (start, end, global_shape) diff --git a/d2o/version.py b/d2o/version.py index 16c79f90f9b421f5e381d6ce3310d73c89b9932a..b03d135ca49f323621e5a3137c23340c9168a4c4 100644 --- a/d2o/version.py +++ b/d2o/version.py @@ -20,4 +20,4 @@ # 1) we don't load dependencies by storing it in __init__.py # 2) we can import it in setup.py for the same reason # 3) we can import it into your module module -__version__ = '1.0.7' +__version__ = '1.0.8'