Skip to content
Snippets Groups Projects
mpi_interface.py 963 B
import dill

# Module-wide MPI state: the world communicator, the number of processes,
# and the rank of this process. Falls back to a serial configuration when
# mpi4py cannot be imported.
try:
    from mpi4py import MPI
except ImportError:
    # No mpi4py: behave as a single-process run. ``comm`` is still defined
    # (as None) so the module exposes the same names in both branches;
    # it should only be used when ``mpi_size > 1``.
    comm = None
    mpi_size = 1
    my_rank = 0
else:
    comm = MPI.COMM_WORLD
    mpi_size = comm.Get_size()
    my_rank = comm.Get_rank()


def get_mpi_start_end_from_list(len_list, start_pt=0, *, size=None, rank=None):
    """Return the half-open index range ``[start_el, end_el)`` owned by a rank.

    ``len_list`` items are split into ``size`` contiguous chunks whose
    lengths differ by at most one: the first ``len_list % size`` ranks each
    take one extra item. Both bounds are offset by ``start_pt``.

    Args:
        len_list: Number of items to distribute (a length, not a list).
        start_pt: Index of the first item; added to both returned bounds.
        size: Number of ranks to split across. Defaults to the module-level
            ``mpi_size``.
        rank: Rank whose chunk is returned. Defaults to the module-level
            ``my_rank`` (the calling process).

    Returns:
        Tuple ``(start_el, end_el)`` such that ``range(start_el, end_el)`` is
        the rank's share. The range is empty when the rank gets no items
        (e.g. ``len_list < size``).

    Raises:
        ValueError: If ``size`` is less than 1 or ``rank`` is not in
            ``[0, size)``.
    """
    if size is None:
        size = mpi_size
    if rank is None:
        rank = my_rank
    if size < 1:
        raise ValueError("size must be >= 1, got {}".format(size))
    if not 0 <= rank < size:
        raise ValueError("rank must be in [0, {}), got {}".format(size, rank))

    els_per_rank, remainder = divmod(len_list, size)

    # Every rank below `remainder` absorbs one leftover item, so the ranks
    # before `rank` contributed min(rank, remainder) extra items in total.
    start_el = start_pt + els_per_rank * rank + min(rank, remainder)
    end_el = start_pt + els_per_rank * (rank + 1) + min(rank + 1, remainder)

    return start_el, end_el


def allgather_object(obj, all2all=False):
    """Exchange ``obj`` among all ranks and return every rank's object.

    Each rank's ``obj`` is serialized with ``dill`` before the exchange. Every
    received payload is deserialized, so all ranks end up with the same list:
    one entry per rank, in rank order.

    Because this issues MPI collective operations, every rank must call it,
    and with the same ``all2all`` value.

    Args:
        obj: Value contributed by the calling rank. It must be serializable
            by ``dill`` when more than one rank is running.
        all2all: If true, exchange with a single ``comm.allgather``.
            Otherwise gather onto rank 0 and broadcast the collected list
            back to all ranks.

    Returns:
        List with one entry per rank. In a single-process run this is
        ``[obj]`` containing the original object itself (no copy is made).
        Otherwise every entry, including this rank's own, is a deserialized
        copy.
    """
    if mpi_size <= 1:
        # Serial run: nothing to exchange.
        return [obj]

    # Serialize with dill up front, presumably so objects the standard
    # pickle used by mpi4py cannot handle can still be exchanged.
    payload = dill.dumps(obj)

    if all2all:
        gathered = comm.allgather(payload)
    else:
        # gather() yields the list only on rank 0 (None elsewhere); the
        # broadcast from rank 0 then hands that list to every rank.
        gathered = comm.bcast(comm.gather(payload, root=0), root=0)

    return list(map(dill.loads, gathered))