Commit 03a8b938 authored by theos's avatar theos
Browse files

Fixed two bugs related to the contraction_helper and the mpi_operator translator.

parent d4b62092
...@@ -602,11 +602,10 @@ class _slicing_distributor(distributor): ...@@ -602,11 +602,10 @@ class _slicing_distributor(distributor):
contracted_global_data = contracted_local_data contracted_global_data = contracted_local_data
new_dist_strategy = parent.distribution_strategy new_dist_strategy = parent.distribution_strategy
new_dtype = contracted_global_data.dtype
if new_shape == (): if new_shape == ():
result = contracted_global_data result = contracted_global_data
else: else:
new_dtype = contracted_global_data.dtype
# try to store the result in a distributed_data_object with the # try to store the result in a distributed_data_object with the
# distribution_strategy as parent # distribution_strategy as parent
result = parent.copy_empty(global_shape=new_shape, result = parent.copy_empty(global_shape=new_shape,
......
...@@ -10,12 +10,12 @@ MPI = gdi[gc['mpi_module']] ...@@ -10,12 +10,12 @@ MPI = gdi[gc['mpi_module']]
custom_MIN = MPI.Op.Create(lambda x, y, datatype: custom_MIN = MPI.Op.Create(lambda x, y, datatype:
np.amin(np.vstack((x, y)), axis=0) np.amin(np.vstack((x, y)), axis=0)
if isinstance(x, np.ndarray) else if isinstance(x, np.ndarray) else
lambda x, y, d: MPI.MIN(x, y)) min(x, y))
custom_MAX = MPI.Op.Create(lambda x, y, datatype: custom_MAX = MPI.Op.Create(lambda x, y, datatype:
np.amax(np.vstack((x, y)), axis=0) np.amax(np.vstack((x, y)), axis=0)
if isinstance(x, np.ndarray) else if isinstance(x, np.ndarray) else
lambda x, y, d: MPI.MAX(x, y)) max(x, y))
custom_NANMIN = MPI.Op.Create(lambda x, y, datatype: custom_NANMIN = MPI.Op.Create(lambda x, y, datatype:
np.nanmin(np.vstack((x, y)), axis=0)) np.nanmin(np.vstack((x, y)), axis=0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment