Commit ad862d52 authored by Theo Steininger's avatar Theo Steininger

Fixed distribute_data for global_shape == ()

parent ee663df9
Pipeline #13536 failed with stage
in 3 minutes and 22 seconds
...@@ -189,7 +189,7 @@ class distributed_data_object(Loggable, Versionable, object): ...@@ -189,7 +189,7 @@ class distributed_data_object(Loggable, Versionable, object):
dtype=dtype, dtype=dtype,
**kwargs) **kwargs)
self.distribution_strategy = distribution_strategy self.distribution_strategy = self.distributor.distribution_strategy
self.dtype = self.distributor.dtype self.dtype = self.distributor.dtype
self.shape = self.distributor.global_shape self.shape = self.distributor.global_shape
self.local_shape = self.distributor.local_shape self.local_shape = self.distributor.local_shape
......
...@@ -320,7 +320,7 @@ class distributor(object): ...@@ -320,7 +320,7 @@ class distributor(object):
result_data = np.empty(self.local_shape, dtype=self.dtype) result_data = np.empty(self.local_shape, dtype=self.dtype)
elif np.isscalar(data): elif np.isscalar(data):
result_data = np.empty(self.local_shape, dtype=self.dtype) result_data = np.empty(self.local_shape, dtype=self.dtype)
result_data[:] = data result_data[()] = data
elif isinstance(data, np.ndarray) or \ elif isinstance(data, np.ndarray) or \
isinstance(data, distributed_data_object) or \ isinstance(data, distributed_data_object) or \
h5py_dataset_Q: h5py_dataset_Q:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment