Commit 9d83ed2f authored by Martin Reinecke
Browse files

first stab at MPI redistribution

parent 45ad79d6
Pipeline #85543 failed with stages
in 6 minutes and 42 seconds
#include "ducc0/infra/system.cc"
#include "ducc0/infra/string_utils.cc"
#include "ducc0/infra/threading.cc"
#include "ducc0/infra/communication.cc"
#include "ducc0/math/pointing.cc"
#include "ducc0/math/geom_utils.cc"
#include "ducc0/math/space_filling.cc"
......
......@@ -186,6 +186,91 @@ void Communicator::all2allvRawVoid (const void *in, const int *numin,
// Broadcasts `num` elements of the element type identified by `type`,
// stored at `data`, from rank `root` over the communicator comm_
// (thin wrapper around MPI_Bcast).
void Communicator::bcastRawVoid (void *data, type_index type, size_t num,
  int root) const
  {
  const auto mpitype = ndt2mpi(type);
  MPI_Bcast (data, num, mpitype, root, comm_);
  }
/// Reorders the leading entries of `v` in place according to `idx`:
/// after the call, v[i] holds the value that v[idx[i]] had before the call,
/// for every i in [0, idx.size()). Entries of `v` at positions >= idx.size()
/// are left unchanged.
template<typename T1, typename T2> inline void rearrange(T1 &v, const T2 &idx)
  {
  // Work from a snapshot so that already overwritten entries are not re-read.
  T1 saved(v);
  const size_t n = idx.size();
  for (size_t dst=0; dst!=n; ++dst)
    {
    const auto src = idx[dst];
    v[dst] = saved[src];
    }
  }
/// Builds and commits an MPI subarray datatype that describes the strided
/// array layout given by `info`, with elements of type `origtype`.
///
/// The axes are sorted by decreasing stride. The extents of the enclosing
/// dense array are then reconstructed from the ratios of consecutive strides
/// (which therefore must divide evenly), and `info`'s shape is used as the
/// sub-block extents, starting at offset 0 in every dimension. One extra
/// innermost dimension of sub-extent 1 accounts for the smallest stride.
///
/// \param info     shape and strides (in units of elements) of the array
/// \param origtype MPI datatype of a single array element
/// \return a committed MPI datatype; the caller is responsible for releasing
///         it with MPI_Type_free.
///
/// Fails an assertion if the array has no dimensions, if a shape or stride
/// does not fit into `int`, if a stride is not positive, or if the strides
/// are not integer multiples of each other.
MPI_Datatype fmav2mpidt(const fmav_info &info, MPI_Datatype origtype)
  {
  const size_t ndim = info.ndim();
  // shape[0] is read unconditionally further down
  MR_assert(ndim>0, "array must have at least one dimension");
  vector<int> shape(ndim), stride(ndim);
  for (size_t i=0; i<ndim; ++i)
    {
    shape[i] = int(info.shape(i));
    stride[i] = int(info.stride(i));
    // MPI only accepts int extents: reject values that were truncated
    MR_assert(size_t(shape[i])==info.shape(i), "shape too large for MPI");
    MR_assert(ptrdiff_t(stride[i])==info.stride(i), "stride too large for MPI");
    // zero strides would cause a division by zero below, negative ones
    // would produce negative array extents
    MR_assert(stride[i]>0, "strides must be positive");
    }

  // reorder the axes such that strides are decreasing (C order)
  vector<size_t> idx(shape.size());
  iota(idx.begin(), idx.end(), 0);
  sort (idx.begin(), idx.end(),
    [&stride](size_t i1, size_t i2) {return stride[i1] > stride[i2];});
  rearrange(shape, idx);
  rearrange(stride, idx);

  // convert strides into extents of the enclosing dense array:
  // the extent of dimension i+1 is stride[i]/stride[i+1]
  for (size_t i=0; i+1<stride.size(); ++i)
    {
    auto tmp = stride[i]/stride[i+1];
    MR_assert(stride[i]==stride[i+1]*tmp, "weird strides");
    stride[i] = tmp;
    }
  // From here on, `stride` holds the full extents and `shape` the sub-block
  // extents of an (ndim+1)-dimensional array: the outermost full extent is
  // taken from the shape itself, the innermost sub-extent is 1.
  shape.push_back(1);
  stride.insert(stride.begin(),shape[0]);

  MPI_Datatype res;
  vector<int> zeros(ndim+1,0);
  MPI_Type_create_subarray(int(shape.size()),
                           stride.data(),  // full extents
                           shape.data(),   // sub-block extents
                           zeros.data(),   // sub-block starts
                           MPI_ORDER_C,
                           origtype,
                           &res);
  MPI_Type_commit(&res);
  return res;
  }
/// Exchanges slabs of a multi-dimensional array between all ranks of the
/// communicator with a single MPI_Alltoallw call.
///
/// The local input array (`in`, layout `iin`) is cut along axis `axout` into
/// one slab per rank; the slab for rank r has the extent that rank r reports
/// for iout.shape(axout). The slab received from rank r is stored in the
/// local output array (`out`, layout `iout`) along axis `axin` and has the
/// extent that rank r reports for iin.shape(axin).
///
/// \param iin   shape/strides of the local input array
/// \param in    pointer to the first element of the local input array
/// \param iout  shape/strides of the local output array
/// \param out   pointer to the first element of the local output array
/// \param axin  axis along which the received slabs are stacked in `out`
/// \param axout axis along which `in` is cut into slabs
/// \param type  element type of both arrays
///
/// Assertions verify that both arrays have the same number of dimensions,
/// that the axes are valid, that all other axes have identical extents, that
/// the extents along `axin`/`axout` are consistent with the values gathered
/// from all ranks, and that all byte displacements fit into `int` (as
/// required by MPI_Alltoallw).
void Communicator::redistributeRawVoid(const fmav_info &iin, const void *in,
  const fmav_info &iout, void *out, size_t axin, size_t axout, type_index type) const
  {
  const auto ndim = iin.ndim();
  const auto nranks = size_t(num_ranks());
  MR_assert(ndim==iout.ndim(), "array dimensions must be equal");
  MR_assert(axin<ndim, "invalid axin");
  MR_assert(axout<ndim, "invalid axout");
  for (size_t i=0; i<ndim; ++i)
    if ((i!=axin) && (i!=axout))
      MR_assert(iin.shape(i)==iout.shape(i), "shape mismatch");

  // extents of all ranks along the redistributed axes
  auto s_in = allgatherVec(int(iin.shape(axin)));
  MR_assert(int(iout.shape(axin))==reduce(s_in.begin(), s_in.end()), "inconsistency");
  auto s_out = allgatherVec(int(iout.shape(axout)));
  MR_assert(int(iin.shape(axout))==reduce(s_out.begin(), s_out.end()), "inconsistency");

  // Byte displacements of the individual slabs. They are accumulated in
  // signed 64-bit arithmetic and verified to fit into the int values that
  // MPI_Alltoallw expects, instead of being truncated silently. This happens
  // before any MPI datatype is created, so a failure here leaks nothing.
  const ptrdiff_t bstr_in = ptrdiff_t(iin.stride(axout))*ptrdiff_t(typesize(type));
  const ptrdiff_t bstr_out = ptrdiff_t(iout.stride(axin))*ptrdiff_t(typesize(type));
  vector<int> disp_in(nranks), disp_out(nranks);
  ptrdiff_t ofs_in=0, ofs_out=0;
  for (size_t i=0; i<nranks; ++i)
    {
    disp_in[i] = int(ofs_in);
    disp_out[i] = int(ofs_out);
    MR_assert(ptrdiff_t(disp_in[i])==ofs_in, "send displacement too large for MPI");
    MR_assert(ptrdiff_t(disp_out[i])==ofs_out, "receive displacement too large for MPI");
    ofs_in += s_out[i]*bstr_in;
    ofs_out += s_in[i]*bstr_out;
    }

  // one datatype per communication partner, describing the slab layout
  const auto mpitype = ndt2mpi(type);
  vector<MPI_Datatype> v_in(nranks), v_out(nranks);
  for (size_t i=0; i<nranks; ++i)
    {
    auto tmp = iin.shape();
    tmp[axout] = s_out[i];
    v_in[i] = fmav2mpidt(fmav_info(tmp, iin.stride()), mpitype);
    tmp = iout.shape();
    tmp[axin] = s_in[i];
    v_out[i] = fmav2mpidt(fmav_info(tmp, iout.stride()), mpitype);
    }

  // exactly one instance of each slab datatype is exchanged per partner
  vector<int> num(nranks, 1);
  MPI_Alltoallw(in, num.data(), disp_in.data(), v_in.data(),
    out, num.data(), disp_out.data(), v_out.data(), comm_);

  for (auto &t: v_in) MPI_Type_free(&t);
  for (auto &t: v_out) MPI_Type_free(&t);
  }
#else
//static
......@@ -262,6 +347,10 @@ void Communicator::all2allvRawVoid (const void *in, const int *numin,
// Fallback variant in the #else branch of the surrounding preprocessor
// conditional: all arguments are ignored and nothing is done.
void Communicator::bcastRawVoid (void * /*data*/, type_index /*type*/,
  size_t /*num*/, int /*root*/) const
  {
  // intentionally empty
  }
// Fallback variant in the #else branch of the surrounding preprocessor
// conditional: all arguments are ignored and the call always fails via
// MR_fail.
void Communicator::redistributeRawVoid(const fmav_info & /*iin*/,
  const void * /*in*/, const fmav_info & /*iout*/, void * /*out*/,
  size_t /*axin*/, size_t /*axout*/, type_index /*type*/) const
  {
  MR_fail("must not get here");
  }
#endif
}}
......@@ -24,7 +24,7 @@
#ifndef DUCC0_COMMUNICATION_H
#define DUCC0_COMMUNICATION_H
#define DUCC0_USE_MPI
//#define DUCC0_USE_MPI
#include <vector>
#ifdef DUCC0_USE_MPI
......@@ -32,6 +32,8 @@
#endif
#include "ducc0/infra/types.h"
#include "ducc0/infra/mav.h"
#include "ducc0/infra/transpose.h"
namespace ducc0 {
......@@ -61,7 +63,6 @@ class Communicator
private:
CommType comm_;
int rank_, num_ranks_;
Communicator(CommType comm);
void sendrecvRawVoid (const void *sendbuf, size_t sendcnt,
size_t dest, void *recvbuf, size_t recvcnt, size_t src, type_index type) const;
......@@ -79,9 +80,13 @@ class Communicator
void all2allvRawVoid (const void *in, const int *numin, const int *disin,
void *out, const int *numout, const int *disout, type_index type) const;
void bcastRawVoid (void *data, type_index type, size_t num, int root) const;
// Type-erased back end of redistribute().
// iin/in:   shape/stride information and data pointer of the input array
// iout/out: shape/stride information and data pointer of the output array
// axin, axout: axes of the arrays involved in the redistribution
// type:     element type of both arrays
void redistributeRawVoid (const fmav_info &iin, const void *in,
const fmav_info &iout, void *out,
size_t axin, size_t axout, type_index type) const;
public:
Communicator();
Communicator(CommType comm);
~Communicator();
Communicator(const Communicator &other) = default;
......@@ -147,6 +152,15 @@ class Communicator
/// Typed front end for bcastRawVoid(): forwards `data`, `num` and `root`
/// together with the type index obtained from tidx<T>().
template<typename T> void bcastRaw (T *data, size_t num, int root=0) const
  {
  const auto type = tidx<T>();
  bcastRawVoid (data, type, num, root);
  }
/// Redistributes the array `in` into `out`.
/// If the communicator has exactly one rank, this is done by calling
/// transpose(in, out); otherwise the work is delegated to
/// redistributeRawVoid(), which receives the array descriptions, pointers to
/// the first elements, both axis indices and the type index of T.
template<typename T> void redistribute (const fmav<T> &in, fmav<T> &out,
  size_t axin, size_t axout) const
  {
  if (num_ranks()!=1)
    {
    redistributeRawVoid(in, &in[0], out, &out.vraw(0), axin, axout, tidx<T>());
    return;
    }
  // single rank: purely local operation
  transpose(in, out);
  }
};
}
......
......@@ -326,6 +326,14 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: tinfo(shp_), tbuf(d_,rw_) {}
/// Constructs an array with shape `shp_`; the buffer base class is
/// initialised with size(), as computed from the array description.
fmav(const shape_t &shp_)
  : tinfo(shp_),
    tbuf(size())
  {}
/// Constructs an array with shape `shp_` and strides `str_`; the buffer
/// base class is initialised with size().
/// Asserts that the layout is compact, i.e. that the offset of the element
/// with the highest index along every axis equals size()-1.
fmav(const shape_t &shp_, const stride_t &str_)
  : tinfo(shp_, str_),
    tbuf(size())
  {
  // offset of the element with the highest index along every axis
  ptrdiff_t last=0;
  size_t ax=0;
  while (ax<ndim())
    {
    last += (ptrdiff_t(shp[ax])-1)*str[ax];
    ++ax;
    }
  MR_assert(last+1==ptrdiff_t(size()), "array is not compact");
  }
/// Constructs an array described by `info` from the data pointer `d_`,
/// which is passed on to the buffer base class.
fmav(const T* d_, const tinfo &info)
  : tinfo(info),
    tbuf(d_)
  {}
fmav(T* d_, const tinfo &info, bool rw_=false)
......
......@@ -47,7 +47,7 @@ template<typename T1, typename T2> inline void rearrange(T1 &v, const T2 &idx)
v[i] = tmp[idx[i]];
}
auto prep(const fmav_info &in, const fmav_info &out)
inline auto prep(const fmav_info &in, const fmav_info &out)
{
MR_assert(in.shape()==out.shape(), "shape mismatch");
shape_t shp;
......@@ -89,7 +89,7 @@ template<typename T, typename Func> void sthelper1(const T *in, T *out,
func(*in, *out);
}
bool critical(ptrdiff_t s)
inline bool critical(ptrdiff_t s)
{
s = (s>=0) ? s : -s;
return (s>4096) && ((s&(s-1))==0);
......@@ -173,6 +173,8 @@ template<typename T, typename Func> void transpose(const fmav<T> &in,
}
iter(in2, out2, 0, 0, 0, func);
}
/// Convenience overload: calls the functor-taking transpose() with a functor
/// that assigns each input element to the corresponding output element.
template<typename T> void transpose(const fmav<T> &in, fmav<T> &out)
  {
  transpose(in, out,
    [](const T &src, T &dst) { dst = src; });
  }
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment