Commit c448ab64 authored by theos's avatar theos
Browse files

Fixed the distribution_strategy handling of cumsum.

parent d66fd65f
Pipeline #2070 skipped
......@@ -1798,19 +1798,7 @@ class distributed_data_object(object):
Contains the results of the cummulative sum.
"""
cumsum_data = self.distributor.cumsum(self.data, axis=axis)
if axis is None:
flat_global_shape = (np.prod(self.shape),)
flat_local_shape = np.shape(cumsum_data)
result_d2o = self.copy_empty(global_shape=flat_global_shape,
local_shape=flat_local_shape)
else:
result_d2o = self.copy_empty()
result_d2o.set_local_data(cumsum_data)
return result_d2o
return self.distributor.cumsum(parent=self, axis=axis)
def save(self, alias, path=None, overwriteQ=True):
""" Saves the distributed_data_object to disk utilizing h5py.
......
......@@ -1506,7 +1506,8 @@ class _slicing_distributor(distributor):
MPI.SUM)
return global_counts
def cumsum(self, data, axis):
def cumsum(self, parent, axis):
data = parent.data
# compute the local np.cumsum
local_cumsum = np.cumsum(data, axis=axis)
if axis is None or axis == 0:
......@@ -1520,7 +1521,28 @@ class _slicing_distributor(distributor):
local_sum_of_shift = np.sum(local_shift_list[:rank],
axis=0)
local_cumsum += local_sum_of_shift
return local_cumsum
# create the return d2o
if axis is None:
# try to preserve the distribution_strategy
flat_global_shape = (self.global_dim, )
flat_local_shape = np.shape(local_cumsum)
result_d2o = parent.copy_empty(global_shape=flat_global_shape,
local_shape=flat_local_shape)
# check if the original distribution strategy yielded a suitable
# local_shape
if result_d2o.local_shape != flat_local_shape:
# if it does not fit, construct a freeform d2o
result_d2o = parent.copy_empty(
global_shape=flat_global_shape,
local_shape=flat_local_shape,
distribution_strategy='freeform')
else:
result_d2o = parent.copy_empty()
result_d2o.set_local_data(local_cumsum, copy=False)
return result_d2o
def _sliceify(self, inp):
sliceified = []
......@@ -2001,10 +2023,13 @@ class _not_distributor(distributor):
minlength=minlength)
return counts
def cumsum(self, data, axis):
def cumsum(self, parent, axis):
data = parent.data
# compute the local results from np.cumsum
cumsum = np.cumsum(data, axis=axis)
return cumsum
local_cumsum = np.cumsum(data, axis=axis)
result_d2o = parent.copy_empty(global_shape=local_cumsum.shape)
result_d2o.set_local_data(local_cumsum, copy=False)
return result_d2o
if 'h5py' in gdi:
def save_data(self, data, alias, path=None, overwriteQ=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment