# mpi_interface.py
import dill
# Module-level MPI state. When mpi4py cannot be imported the module falls back
# to a single-process configuration. `comm` is then bound to None so the name
# exists in either case and callers can test `comm is None` instead of
# hitting a NameError/ImportError.
try:
    from mpi4py import MPI
except ImportError:
    comm = None
    mpi_size = 1
    my_rank = 0
else:
    comm = MPI.COMM_WORLD
    mpi_size = comm.Get_size()  # total number of ranks in COMM_WORLD
    my_rank = comm.Get_rank()   # this process's rank, 0 <= my_rank < mpi_size
def get_mpi_start_end_from_list(len_list, start_pt=0, *, rank=None, size=None):
    """Return the half-open index range ``[start_el, end_el)`` owned by a rank.

    ``len_list`` items are split as evenly as possible across ``size`` ranks.
    Every rank gets ``len_list // size`` items, and the first
    ``len_list % size`` ranks get one extra. The chunks are therefore
    contiguous and ordered by rank, and together they cover
    ``[start_pt, start_pt + len_list)``.

    Args:
        len_list: Number of items to distribute.
        start_pt: Offset added to both returned indices.
        rank: Rank whose chunk is computed. Defaults to this process's rank
            (``my_rank``).
        size: Number of ranks to split across. Defaults to ``mpi_size``.

    Returns:
        Tuple ``(start_el, end_el)``. The range is empty
        (``start_el == end_el``) when there are more ranks than items.
    """
    if rank is None:
        rank = my_rank
    if size is None:
        size = mpi_size

    base, remainder = divmod(len_list, size)
    # Ranks below `remainder` each took one extra item, which shifts the start.
    start_el = start_pt + base * rank + min(rank, remainder)
    end_el = start_el + base + (1 if rank < remainder else 0)
    return start_el, end_el
def allgather_object(obj, all2all=False):
    """Collect ``obj`` from every MPI rank and return the full list on every rank.

    Args:
        obj: Object contributed by this rank. It is serialized with ``dill``
            before being sent (presumably so objects the standard pickle
            module rejects can still be exchanged; confirm with callers).
        all2all: If true, exchange with a single ``comm.allgather``.
            Otherwise gather to rank 0 and broadcast the result back out.

    Returns:
        A list with one entry per rank, in rank order. With a single process
        the list is ``[obj]`` and holds the caller's own object. With several
        ranks every entry, including this rank's, is a freshly deserialized
        copy.
    """
    if not mpi_size > 1:
        return [obj]

    payload = dill.dumps(obj)
    if all2all:
        blobs = comm.allgather(payload)
    else:
        # Two-step exchange: rank 0 collects every payload, then rebroadcasts
        # the complete list so all ranks end up holding it.
        blobs = comm.bcast(comm.gather(payload, root=0), root=0)
    return list(map(dill.loads, blobs))