Commit 6a69a231 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first beginnings

parent fd7a243c
Pipeline #74518 passed with stages
in 8 minutes and 50 seconds
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <unordered_map>
#include "mr_util/infra/communication.h"
#include "mr_util/infra/error_handling.h"
namespace mr {
namespace detail_communication {
using namespace std;
void assert_unequal (const void *a, const void *b)
{ MR_assert (a!=b, "input and output buffers must not be identical"); }
#ifdef MRUTIL_USE_MPI
class Typemap: public TypeMapper<MPI_Datatype>
{
public:
Typemap()
{
add<double>(MPI_DOUBLE);
add<float>(MPI_FLOAT);
add<int>(MPI_INT);
add<long>(MPI_LONG);
add<char>(MPI_CHAR);
add<unsigned char>(MPI_BYTE);
// etc.
}
};
Typemap typemap;
MPI_Datatype ndt2mpi (type_index type)
{ return typemap[type]; }
MPI_Op op2mop (Communicator::redOp op)
{
switch (op)
{
case Communicator::Min : return MPI_MIN;
case Communicator::Max : return MPI_MAX;
case Communicator::Sum : return MPI_SUM;
case Communicator::Prod: return MPI_PROD;
default: MR_fail ("unsupported reduction operation");
}
}
//static
void Communication::init()
{
MPI_Init(0,0);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_ARE_FATAL);
}
//static
bool Communication::initialized()
{
int flag=0;
MPI_Initialized(&flag);
return flag;
}
//static
void Communication::finalize()
{ MPI_Finalize(); }
//static
void Communication::abort()
{
if (initialized())
MPI_Abort(MPI_COMM_WORLD, 1);
else
exit(1);
}
Communicator::Communicator(CommType comm)
: comm_(comm)
{
MPI_Comm_size(comm_, &num_ranks_);
MPI_Comm_rank(comm_, &rank_);
}
Communicator::Communicator()
: Communicator(MPI_COMM_WORLD) {}
Communicator::~Communicator()
{
if (comm_!=MPI_COMM_WORLD)
MPI_Comm_free(&comm_);
}
void Communicator::barrier() const
{ MPI_Barrier(comm_); }
Communicator Communicator::split(size_t color) const
{
MPI_Comm comm;
MPI_Comm_split (comm_, color, rank_, &comm);
return Communicator(comm);
}
void Communicator::sendrecvRawVoid (const void *sendbuf, size_t sendcnt,
size_t dest, void *recvbuf, size_t recvcnt, size_t src, type_index type) const
{
if ((sendcnt>0)&&(recvcnt>0)) assert_unequal(sendbuf,recvbuf);
MPI_Datatype dtype = ndt2mpi(type);
MPI_Sendrecv (const_cast<void *>(sendbuf),sendcnt,dtype,dest,0,
recvbuf,recvcnt,dtype,src,0,comm_,MPI_STATUS_IGNORE);
}
void Communicator::sendrecv_replaceRawVoid (void *data, type_index type, size_t num,
size_t dest, size_t src) const
{
MPI_Sendrecv_replace (data,num,ndt2mpi(type),dest,0,src,0,comm_,
MPI_STATUS_IGNORE);
}
void Communicator::allreduceRawVoid (const void *in, void *out, type_index type,
size_t num, redOp op) const
{
void *in2 = (in==out) ? MPI_IN_PLACE : const_cast<void *>(in);
MPI_Allreduce (in2,out,num,ndt2mpi(type),op2mop(op),comm_);
}
void Communicator::allgatherRawVoid (const void *in, void *out, type_index type,
size_t num) const
{
if (num>0) assert_unequal(in,out);
MPI_Datatype tp = ndt2mpi(type);
MPI_Allgather (const_cast<void *>(in),num,tp,out,num,tp,comm_);
}
void Communicator::allgathervRawVoid (const void *in, int numin, void *out,
const int *numout, const int *disout, type_index type) const
{
if (numin>0) assert_unequal(in,out);
MR_assert(numin==numout[rank_],"inconsistent arguments");
MPI_Datatype tp = ndt2mpi(type);
MPI_Allgatherv (const_cast<void *>(in),numin,tp,out,const_cast<int *>(numout),
const_cast<int *>(disout),tp,comm_);
}
void Communicator::all2allRawVoid (const void *in, void *out, type_index type,
size_t num) const
{
void *in2 = (in==out) ? MPI_IN_PLACE : const_cast<void *>(in);
MR_assert (num%num_ranks_==0,
"array size is not divisible by number of ranks");
MPI_Datatype tp = ndt2mpi(type);
MPI_Alltoall (in2,num/num_ranks_,tp,out,num/num_ranks_,tp,comm_);
}
void Communicator::all2allvRawVoid (const void *in, const int *numin,
const int *disin, void *out, const int *numout, const int *disout, type_index type)
const
{
long commsz=disin[num_ranks_-1]+numin[num_ranks_-1]
+disout[num_ranks_-1]+numout[num_ranks_-1];
if (commsz>0) assert_unequal(in,out);
MPI_Datatype tp = ndt2mpi(type);
MPI_Alltoallv (const_cast<void *>(in), const_cast<int *>(numin),
const_cast<int *>(disin), tp, out, const_cast<int *>(numout),
const_cast<int *>(disout), tp, comm_);
}
void Communicator::bcastRawVoid (void *data, type_index type, size_t num, int root) const
{ MPI_Bcast (data,num,ndt2mpi(type),root,comm_); }
#else
//static
void Communication::init() {}
//static
bool Communication::initialized()
{ return true; }
//static
void Communication::finalize() {}
//static
void Communication::abort()
{ exit(1); }
Communicator::Communicator()
: rank_(0), num_ranks_(1) {}
Communicator::~Communicator() {}
void Communicator::barrier() const {}
Communicator Communicator::split(size_t /*color*/) const
{ return *this; }
void Communicator::sendrecvRawVoid (const void *sendbuf, size_t sendcnt,
size_t dest, void *recvbuf, size_t recvcnt, size_t src, type_index type) const
{
MR_assert ((dest==0) && (src==0), "inconsistent call");
MR_assert (sendcnt==recvcnt, "inconsistent call");
if (sendcnt>0) assert_unequal(sendbuf,recvbuf);
memcpy (recvbuf, sendbuf, sendcnt*typesize(type));
}
void Communicator::sendrecv_replaceRawVoid (void *, type_index, size_t, size_t dest,
size_t src) const
{ MR_assert ((dest==0) && (src==0), "inconsistent call"); }
void Communicator::allreduceRawVoid (const void *in, void *out, type_index type,
size_t num, redOp /*op*/) const
{
if (in==out) return;
memcpy (out, in, num*typesize(type));
}
void Communicator::allgatherRawVoid (const void *in, void *out, type_index type,
size_t num) const
{ if (num>0) assert_unequal(in,out); memcpy (out, in, num*typesize(type)); }
void Communicator::all2allRawVoid (const void *in, void *out, type_index type,
size_t num) const
{
if (in==out) return;
memcpy (out, in, num*typesize(type));
}
void Communicator::allgathervRawVoid (const void *in, int numin, void *out,
const int *numout, const int *disout, type_index type) const
{
if (numin>0) assert_unequal(in,out);
MR_assert(numin==numout[0],"inconsistent call");
memcpy (reinterpret_cast<char *>(out)+disout[0]*typesize(type), in,
numin*typesize(type));
}
void Communicator::all2allvRawVoid (const void *in, const int *numin,
const int *disin, void *out, const int *numout, const int *disout, type_index type)
const
{
if (numin[0]>0) assert_unequal(in,out);
MR_assert (numin[0]==numout[0],"message size mismatch");
const char *in2 = static_cast<const char *>(in);
char *out2 = static_cast<char *>(out);
size_t st=typesize(type);
memcpy (out2+disout[0]*st,in2+disin[0]*st,numin[0]*st);
}
void Communicator::bcastRawVoid (void *, type_index, size_t, int) const
{}
#endif
}}
/*
* This file is part of the MR utility library.
*
* This code is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this code; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
/*
* Copyright (C) 2009-2020 Max-Planck-Society
* \author Martin Reinecke
*/
#ifndef MRUTIL_COMMUNICATION_H
#define MRUTIL_COMMUNICATION_H
#include <vector>
#ifdef MRUTIL_USE_MPI
#include <mpi.h>
#endif
#include "mr_util/infra/types.h"
namespace mr {
namespace detail_communication {
using namespace std;
class Communication
{
public:
static void init();
static bool initialized();
static void finalize();
static void abort();
};
class Communicator
{
public:
enum redOp { Sum, Min, Max, Prod };
#ifdef MRUTIL_USE_MPI
using CommType = MPI_Comm;
#else
using CommType = struct{};
#endif
private:
CommType comm_;
int rank_, num_ranks_;
Communicator(CommType comm);
void sendrecvRawVoid (const void *sendbuf, size_t sendcnt,
size_t dest, void *recvbuf, size_t recvcnt, size_t src, type_index type) const;
void sendrecv_replaceRawVoid (void *data, type_index type, size_t num,
size_t dest, size_t src) const;
void allreduceRawVoid (const void *in, void *out, type_index type, size_t num,
redOp op) const;
void allgatherRawVoid (const void *in, void *out, type_index type, size_t num)
const;
void allgathervRawVoid (const void *in, int numin, void *out,
const int *numout, const int *disout, type_index type) const;
/*! NB: \a num refers to the <i>total</i> number of items in the arrays;
the individual message size is \a num/num_ranks(). */
void all2allRawVoid (const void *in, void *out, type_index type, size_t num) const;
void all2allvRawVoid (const void *in, const int *numin, const int *disin,
void *out, const int *numout, const int *disout, type_index type) const;
void bcastRawVoid (void *data, type_index type, size_t num, int root) const;
public:
Communicator();
~Communicator();
Communicator(const Communicator &other) = default;
int num_ranks() const { return num_ranks_; }
int rank() const { return rank_; }
bool master() const { return rank_==0; }
CommType comm() const { return comm_; }
void barrier() const;
Communicator split(size_t subgroup) const;
template<typename T> void sendrecvRaw (const T *sendbuf, size_t sendcnt,
size_t dest, T *recvbuf, size_t recvcnt, size_t src) const
{
sendrecvRawVoid(sendbuf, sendcnt, dest, recvbuf, recvcnt, src,
tidx<T>());
}
template<typename T> void sendrecv_replaceRaw (T *data, size_t num,
size_t dest, size_t src) const
{ sendrecv_replaceRawVoid(data, tidx<T>(), num, dest, src); }
template<typename T> void allreduceRaw (const T *in, T *out, size_t num,
redOp op) const
{ allreduceRawVoid (in, out, tidx<T>(), num, op); }
template<typename T> void allgatherRaw (const T *in, T *out, size_t num)
const
{ allgatherRawVoid (in, out, tidx<T>(), num); }
template<typename T> void allgathervRaw (const T *in, int numin, T *out,
const int *numout, const int *disout) const
{ allgathervRawVoid (in, numin, out, numout, disout, tidx<T>()); }
template<typename T> T allreduce(const T &in, redOp op) const
{
T out;
allreduceRaw (&in, &out, 1, op);
return out;
}
template<typename T> std::vector<T> allreduce
(const std::vector<T> &in, redOp op) const
{
std::vector<T> out(in.size());
allreduceRaw (in.data(), out.data(), in.size(), op);
return out;
}
/*! NB: \a num refers to the <i>total</i> number of items in the arrays;
the individual message size is \a num/num_ranks(). */
template<typename T> void all2allRaw (const T *in, T *out, size_t num) const
{ all2allRawVoid (in, out, tidx<T>(), num); }
template<typename T> void all2allvRaw (const T *in, const int *numin,
const int *disin, T *out, const int *numout, const int *disout) const
{ all2allvRawVoid (in,numin,disin,out,numout,disout,tidx<T>()); }
template<typename T> void bcastRaw (T *data, size_t num, int root=0) const
{ bcastRawVoid (data, tidx<T>(), num, root); }
};
}
using detail_communication::Communication;
using detail_communication::Communicator;
}
#endif
/*
* This file is part of the MR utility library.
*
* This code is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this code; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
/* Copyright (C) 2020 Max-Planck-Society
Author: Martin Reinecke */
#include <cstdint>
#include "mr_util/infra/error_handling.h"
#include "mr_util/infra/types.h"
using namespace std;
namespace mr {
namespace detail_types {
using namespace std;
class Sizemap: public TypeMapper<size_t>
{
protected:
template<typename T, typename... Ts> void addTypes()
{
add<T>(sizeof(T));
if constexpr (sizeof...(Ts)>0) addTypes<Ts...>();
}
public:
Sizemap()
{
addTypes<double, float, int, long, size_t, ptrdiff_t,
int32_t, int64_t, uint32_t, uint64_t>();
}
};
Sizemap sizemap;
size_t typesize(const type_index &idx)
{ return sizemap[idx]; }
}}
/*
* This file is part of the MR utility library.
*
* This code is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this code; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
/* Copyright (C) 2020 Max-Planck-Society
Author: Martin Reinecke */
#ifndef MRUTIL_TYPES_H
#define MRUTIL_TYPES_H
#include <typeinfo>
#include <typeindex>
#include <cstddef>
#include <unordered_map>
#include "mr_util/infra/error_handling.h"
namespace mr {
namespace detail_types {
using namespace std;
template<typename T> constexpr inline auto tidx()
{ return type_index(typeid(T)); }
template<typename DT> class TypeMapper
{
protected:
unordered_map<type_index, DT> mapping;
public:
template<typename T> void add (const DT &dt)
{ mapping[tidx<T>()] = dt; }
DT operator[](const type_index &idx) const
{
auto res = mapping.find(idx);
MR_assert(res!=mapping.end(), "type not found");
return res->second;
}
};
size_t typesize(const type_index &idx);
}
using detail_types::tidx;
using detail_types::typesize;
using detail_types::TypeMapper;
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment