diff --git a/libsharp2/sharp_mpi.c b/libsharp2/sharp_mpi.c index c90f41bc815ed304bdc1e45944ab815406aa2b03..e0d4804d10e30dd202e2180b3b906b62d0f4932a 100644 --- a/libsharp2/sharp_mpi.c +++ b/libsharp2/sharp_mpi.c @@ -149,13 +149,34 @@ static void measure_drift(const sharp_mpi_info *minfo, const char *msg) printf("drift at %s: %e\n", msg, time2-timered); } -static void sharp_communicate_alm2map (const sharp_mpi_info *minfo, dcmplx **ph) +static void alloc_phase_mpi(sharp_job *job, int nm, int ntheta, int nmfull, int nthetafull, int full_theta, int fast_theta) { -printf("task %d arrived at %e\n", minfo->mytask, MPI_Wtime()); -MPI_Barrier(minfo->comm); -double time=MPI_Wtime(); + if (full_theta) + ntheta = nthetafull; + else + nm = nmfull; + if (fast_theta) + { + job->s_th=2*job->nmaps; + if (((job->s_th*16*ntheta)&1023)==0) ntheta+=3; // hack to avoid critical strides + job->s_m=job->s_th*ntheta; + } + else + { + job->s_m=2*job->nmaps; + if (((job->s_m*16*nm)&1023)==0) nm+=3; // hack to avoid critical strides + job->s_th=job->s_m*nm; + } + job->phase=RALLOC(dcmplx,2*job->nmaps*nm*ntheta); + } + +static void sharp_communicate_alm2map (const sharp_mpi_info *minfo, sharp_job *job) + { +//printf("task %d arrived at %e\n", minfo->mytask, MPI_Wtime()); +//MPI_Barrier(minfo->comm); +//double time=MPI_Wtime(); +//printf("%d %d \n", job->s_m, job->s_th); - // on input: ph has shape(npairtotal,nm[task],nph) dcmplx *sendbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); for (int task=0; task<minfo->ntasks; ++task) for (int ti=0; ti<minfo->npair[task]; ++ti) @@ -164,18 +185,24 @@ double time=MPI_Wtime(); int th = minfo->ofs_pair[task] + ti; int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task +minfo->nph*(mi + minfo->nmmax*ti); - int oarr = minfo->nph*(th*(minfo->nm[minfo->mytask]) + mi); + int oarr = mi*job->s_m + th*job->s_th; for (int i=0; i<minfo->nph; ++i) - sendbuf[obuf+i] = (*ph)[oarr+i]; + sendbuf[obuf+i] = (job->phase)[oarr+i]; } - DEALLOC(*ph); + DEALLOC(job->phase); +//MPI_Barrier(minfo->comm); +//printf("xtask %d arrived at %e\n", minfo->mytask, MPI_Wtime()); dcmplx *recvbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); MPI_Alltoall (sendbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, recvbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, minfo->comm); DEALLOC(sendbuf); - ALLOC(*ph,dcmplx,minfo->nph*minfo->npair[minfo->mytask]*minfo->nmtotal); - // on output: ph has shape(npair[task],mmax+1,nph) + alloc_phase_mpi(job,job->ainfo->nm,job->ginfo->npairs,minfo->mmax+1, + minfo->npairtotal, job->type==SHARP_MAP2ALM, job->type==SHARP_MAP2ALM); +//MPI_Barrier(minfo->comm); +//printf("ytask %d arrived at %e\n", minfo->mytask, MPI_Wtime()); +//printf("%d %d \n", job->s_m, job->s_th); + for (int task=0; task<minfo->ntasks; ++task) for (int ti=0; ti<minfo->npair[minfo->mytask]; ++ti) for (int mi=0; mi<minfo->nm[task]; ++mi) @@ -183,17 +210,19 @@ double time=MPI_Wtime(); int m = minfo->mval[mi+minfo->ofs_m[task]]; int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task +minfo->nph*(mi + minfo->nmmax*ti); - int oarr = minfo->nph*(ti*(minfo->mmax+1) + m); + int oarr = m*job->s_m + ti*job->s_th; for (int i=0; i<minfo->nph; ++i) - (*ph)[oarr+i] = recvbuf[obuf+i]; + (job->phase)[oarr+i] = recvbuf[obuf+i]; } +//MPI_Barrier(minfo->comm); +//printf("ztask %d arrived at %e\n", minfo->mytask, MPI_Wtime()); DEALLOC(recvbuf); -MPI_Barrier(minfo->comm); -if (minfo->mytask==0) printf("time for alm2map communication: %e\n", MPI_Wtime()-time); +//MPI_Barrier(minfo->comm); +//if (minfo->mytask==0) printf("time for alm2map communication: %e\n", MPI_Wtime()-time); } -static void sharp_communicate_map2alm (const sharp_mpi_info *minfo, dcmplx **ph) +static void sharp_communicate_map2alm (const sharp_mpi_info *minfo, sharp_job *job) { dcmplx *sendbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); for (int task=0; task<minfo->ntasks; ++task) @@ -203,17 +232,20 @@ static void sharp_communicate_map2alm (const sharp_mpi_info *minfo, dcmplx **ph) int m = minfo->mval[mi+minfo->ofs_m[task]]; int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task +minfo->nph*(mi + minfo->nmmax*ti); - int oarr = minfo->nph*(ti*(minfo->mmax+1) + m); + int oarr = m*job->s_m + ti*job->s_th; for (int i=0; i<minfo->nph; ++i) - sendbuf[obuf+i] = (*ph)[oarr+i]; + sendbuf[obuf+i] = (job->phase)[oarr+i]; } - DEALLOC(*ph); + DEALLOC(job->phase); dcmplx *recvbuf = RALLOC(dcmplx,minfo->nmmax*minfo->ntasks*minfo->npairmax*minfo->nph); MPI_Alltoall (sendbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, recvbuf, minfo->nph*minfo->nmmax*minfo->npairmax*2,MPI_DOUBLE, minfo->comm); DEALLOC(sendbuf); - ALLOC(*ph,dcmplx,minfo->nph*minfo->nm[minfo->mytask]*minfo->npairtotal); + + alloc_phase_mpi(job,job->ainfo->nm,job->ginfo->npairs,minfo->mmax+1, + minfo->npairtotal, job->type==SHARP_MAP2ALM, job->type==SHARP_MAP2ALM); + for (int task=0; task<minfo->ntasks; ++task) for (int ti=0; ti<minfo->npair[task]; ++ti) for (int mi=0; mi<minfo->nm[minfo->mytask]; ++mi) @@ -221,39 +253,23 @@ static void sharp_communicate_map2alm (const sharp_mpi_info *minfo, dcmplx **ph) int th = minfo->ofs_pair[task] + ti; int obuf = minfo->nmmax*minfo->npairmax*minfo->nph*task +minfo->nph*(mi + minfo->nmmax*ti); - int oarr = minfo->nph*(th*(minfo->nm[minfo->mytask]) + mi); + int oarr = mi*job->s_m + th*job->s_th; for (int i=0; i<minfo->nph; ++i) - (*ph)[oarr+i] = recvbuf[obuf+i]; + (job->phase)[oarr+i] = recvbuf[obuf+i]; } DEALLOC(recvbuf); } -static void alloc_phase_mpi (sharp_job *job, int nm, int ntheta, - int nmfull, int nthetafull) - { - ptrdiff_t phase_size = (job->type==SHARP_MAP2ALM) ? - (ptrdiff_t)(nmfull)*ntheta : (ptrdiff_t)(nm)*nthetafull; - job->phase=RALLOC(dcmplx,2*job->nmaps*phase_size); - job->s_m=2*job->nmaps; - job->s_th = job->s_m * ((job->type==SHARP_MAP2ALM) ? nmfull : nm); - } - static void alm2map_comm (sharp_job *job, const sharp_mpi_info *minfo) { if (job->type != SHARP_MAP2ALM) - { - sharp_communicate_alm2map (minfo,&job->phase); - job->s_th=job->s_m*minfo->nmtotal; - } + sharp_communicate_alm2map (minfo,job); } static void map2alm_comm (sharp_job *job, const sharp_mpi_info *minfo) { if (job->type == SHARP_MAP2ALM) - { - sharp_communicate_map2alm (minfo,&job->phase); - job->s_th=job->s_m*minfo->nm[minfo->mytask]; - } + sharp_communicate_map2alm (minfo,job); } static void sharp_execute_job_mpi (sharp_job *job, MPI_Comm comm) @@ -295,19 +311,19 @@ static void sharp_execute_job_mpi (sharp_job *job, MPI_Comm comm) } else { -measure_drift(&minfo,"start"); +//measure_drift(&minfo,"start"); int lmax = job->ainfo->lmax; job->norm_l = sharp_Ylmgen_get_norm (lmax, job->spin); -measure_drift(&minfo,"after get_norm"); +//measure_drift(&minfo,"after get_norm"); /* clear output arrays if requested */ init_output (job); - alloc_phase_mpi (job,job->ainfo->nm,job->ginfo->npairs,minfo.mmax+1, - minfo.npairtotal); + alloc_phase_mpi(job,job->ainfo->nm,job->ginfo->npairs,minfo.mmax+1, + minfo.npairtotal, job->type!=SHARP_MAP2ALM, job->type!=SHARP_MAP2ALM); double *cth = RALLOC(double,minfo.npairtotal), - *sth = RALLOC(double,minfo.npairtotal); + *sth = RALLOC(double,minfo.npairtotal); int *mlim = RALLOC(int,minfo.npairtotal); for (int i=0; i<minfo.npairtotal; ++i) {