diff --git a/libsharp2/sharp_mpi.c b/libsharp2/sharp_mpi.c index 3b7ca7f50037a9b7f2a9626835c08745be07d14b..2d3e5f6538d42f688f577d00b209ac0d41aa5e85 100644 --- a/libsharp2/sharp_mpi.c +++ b/libsharp2/sharp_mpi.c @@ -50,6 +50,9 @@ typedef struct int *ispair; /* is this really a pair? */ int *almcount, *almdisp, *mapcount, *mapdisp; /* for all2all communication */ + + int npairmax; /* maxium value in npair */ + int nmmax; /* maximum value in nm */ } sharp_mpi_info; static void sharp_make_mpi_info (MPI_Comm comm, const sharp_job *job, @@ -111,6 +114,13 @@ static void sharp_make_mpi_info (MPI_Comm comm, const sharp_job *job, minfo->mapcount[i] = 2*minfo->nph*minfo->nm[i]*minfo->npair[minfo->mytask]; minfo->mapdisp[i+1] = minfo->mapdisp[i]+minfo->mapcount[i]; } + minfo->npairmax = minfo->nmmax = 0; + for (int i=0; i<minfo->ntasks; ++i) + { + if (minfo->npair[i]>minfo->npairmax) minfo->npairmax = minfo->npair[i]; + if (minfo->nm[i]>minfo->nmmax) minfo->nmmax = minfo->nm[i]; + } + } static void sharp_destroy_mpi_info (sharp_mpi_info *minfo) @@ -130,49 +140,74 @@ static void sharp_destroy_mpi_info (sharp_mpi_info *minfo) static void sharp_communicate_alm2map (const sharp_mpi_info *minfo, dcmplx **ph) { - dcmplx *phas_tmp = RALLOC(dcmplx,minfo->mapdisp[minfo->ntasks]/2); - - MPI_Alltoallv (*ph,minfo->almcount,minfo->almdisp,MPI_DOUBLE,phas_tmp, - minfo->mapcount,minfo->mapdisp,MPI_DOUBLE,minfo->comm); - + // on input: ph has shape(npairtotal,nm[task],nph) + dcmplx *sendbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); + for (int task=0; task<minfo->ntasks; ++task) + for (int ti=0; ti<minfo->npair[task]; ++ti) + for (int mi=0; mi<minfo->nm[minfo->mytask]; ++mi) + { + int th = minfo->ofs_pair[task] + ti; + int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task + +minfo->nph*(mi + minfo->nmmax*ti); + int oarr = minfo->nph*(th*(minfo->nm[minfo->mytask]) + mi); + for (int i=0; i<minfo->nph; ++i) + sendbuf[obuf+i] = (*ph)[oarr+i]; + } DEALLOC(*ph); + dcmplx *recvbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); + MPI_Alltoall (sendbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, + recvbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, + minfo->comm); + DEALLOC(sendbuf); ALLOC(*ph,dcmplx,minfo->nph*minfo->npair[minfo->mytask]*minfo->nmtotal); - + // on output: ph has shape(npair[task],mmax+1,nph) for (int task=0; task<minfo->ntasks; ++task) - for (int th=0; th<minfo->npair[minfo->mytask]; ++th) + for (int ti=0; ti<minfo->npair[minfo->mytask]; ++ti) for (int mi=0; mi<minfo->nm[task]; ++mi) { int m = minfo->mval[mi+minfo->ofs_m[task]]; - int o1 = minfo->nph*(th*(minfo->mmax+1) + m); - int o2 = minfo->mapdisp[task]/2+minfo->nph*(mi+th*minfo->nm[task]); + int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task + +minfo->nph*(mi + minfo->nmmax*ti); + int oarr = minfo->nph*(ti*(minfo->mmax+1) + m); for (int i=0; i<minfo->nph; ++i) - (*ph)[o1+i] = phas_tmp[o2+i]; + (*ph)[oarr+i] = recvbuf[obuf+i]; } - DEALLOC(phas_tmp); + DEALLOC(recvbuf); } static void sharp_communicate_map2alm (const sharp_mpi_info *minfo, dcmplx **ph) { - dcmplx *phas_tmp = RALLOC(dcmplx,minfo->mapdisp[minfo->ntasks]/2); - + dcmplx *sendbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); for (int task=0; task<minfo->ntasks; ++task) - for (int th=0; th<minfo->npair[minfo->mytask]; ++th) + for (int ti=0; ti<minfo->npair[minfo->mytask]; ++ti) for (int mi=0; mi<minfo->nm[task]; ++mi) { int m = minfo->mval[mi+minfo->ofs_m[task]]; - int o1 = minfo->mapdisp[task]/2+minfo->nph*(mi+th*minfo->nm[task]); - int o2 = minfo->nph*(th*(minfo->mmax+1) + m); + int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task + +minfo->nph*(mi + minfo->nmmax*ti); + int oarr = minfo->nph*(ti*(minfo->mmax+1) + m); for (int i=0; i<minfo->nph; ++i) - phas_tmp[o1+i] = (*ph)[o2+i]; + sendbuf[obuf+i] = (*ph)[oarr+i]; } - DEALLOC(*ph); + dcmplx *recvbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); + MPI_Alltoall (sendbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, + recvbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, + minfo->comm); + DEALLOC(sendbuf); ALLOC(*ph,dcmplx,minfo->nph*minfo->nm[minfo->mytask]*minfo->npairtotal); - - MPI_Alltoallv (phas_tmp,minfo->mapcount,minfo->mapdisp,MPI_DOUBLE, - *ph,minfo->almcount,minfo->almdisp,MPI_DOUBLE,minfo->comm); - - DEALLOC(phas_tmp); + for (int task=0; task<minfo->ntasks; ++task) + for (int ti=0; ti<minfo->npair[task]; ++ti) + for (int mi=0; mi<minfo->nm[minfo->mytask]; ++mi) + { + int th = minfo->ofs_pair[task] + ti; + int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task + +minfo->nph*(mi + minfo->nmmax*ti); + int oarr = minfo->nph*(th*(minfo->nm[minfo->mytask]) + mi); + for (int i=0; i<minfo->nph; ++i) + (*ph)[oarr+i] = recvbuf[obuf+i]; + } + DEALLOC(recvbuf); } static void alloc_phase_mpi (sharp_job *job, int nm, int ntheta,