Commit 72493960 authored by theos's avatar theos

Added get_axes_local_distribution_strategy method.

parent 65098d3f
...@@ -1760,6 +1760,10 @@ class distributed_data_object(object): ...@@ -1760,6 +1760,10 @@ class distributed_data_object(object):
return self.distributor.consolidate_data(self.data, return self.distributor.consolidate_data(self.data,
target_rank=target_rank) target_rank=target_rank)
def get_axes_local_distribution_strategy(self, axes):
axes = cast_axis_to_tuple(axes, len(self.shape))
return self.distributor.get_axes_local_distribution_strategy(axes)
def flatten(self, inplace=False): def flatten(self, inplace=False):
""" Returns a flat copy of the d2o collapsed into one dimension. """ Returns a flat copy of the d2o collapsed into one dimension.
......
...@@ -1830,6 +1830,12 @@ class _slicing_distributor(distributor): ...@@ -1830,6 +1830,12 @@ class _slicing_distributor(distributor):
def get_iter(self, d2o): def get_iter(self, d2o):
return d2o_slicing_iter(d2o) return d2o_slicing_iter(d2o)
def get_axes_local_distribution_strategy(self, axes):
if 0 in axes:
return self.distribution_strategy
else:
return 'not'
def _equal_slicer(comm, global_shape): def _equal_slicer(comm, global_shape):
rank = comm.rank rank = comm.rank
...@@ -2140,3 +2146,6 @@ class _not_distributor(distributor): ...@@ -2140,3 +2146,6 @@ class _not_distributor(distributor):
def get_iter(self, d2o): def get_iter(self, d2o):
return d2o_not_iter(d2o) return d2o_not_iter(d2o)
def get_axes_local_distribution_strategy(self, axes):
return 'not'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment