Commit b6ce5fd4 authored by Alexander Heinecke

First attempt to fix the bug whose earlier fix was rolled back in 402629d9.

The current fix does as much blocking as possible, which should be
beneficial from both a compute and a communication point of view.

Additionally, a second possible fix was added which simply calls
the blocked version only if the local matrix is sufficiently large.
This might create more, smaller messages at scale.
parent 402629d9
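
For reference, the blocked step count and the per-step column count introduced by this commit behave as in the following minimal, standalone sketch. It is not part of the commit: the driver program and its input values are hypothetical, while the names na, nbw, t_blocking, cwy_blocking and n_cols follow the diff below.

! Minimal sketch (hypothetical driver, not part of the commit): reproduces the
! per-step column count of the blocked code path, including the new special
! case for na < cwy_blocking.
program ncols_sketch
  implicit none
  integer :: na, nbw, t_blocking, cwy_blocking, istep, nsteps, n_cols

  na  = 100        ! global matrix size (hypothetical value)
  nbw = 32         ! band width (hypothetical value)
  t_blocking   = 3 ! number of T matrices aggregated per step, as in the diff
  cwy_blocking = t_blocking * nbw

  nsteps = ((na-1)/nbw - 1)/t_blocking + 1 ! number of blocked steps
  do istep = 1, nsteps
    if (na < cwy_blocking) then
      ! special case added by this commit: only a partial step is possible
      n_cols = max(0, na - nbw)
      if (n_cols == 0) exit
    else
      n_cols = min(na, istep*cwy_blocking + nbw) - (istep-1)*cwy_blocking - nbw
    end if
    print '(a,i0,a,i0)', 'istep = ', istep, ', n_cols = ', n_cols
  end do
end program ncols_sketch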
@@ -1173,13 +1173,11 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
   max_local_rows = max_blocks_row*nblk
   max_local_cols = max_blocks_col*nblk
-  ! This conditional was introduced due to a merge error. For better performance this code path should
-  ! always be used
-  if (useQR) then
-    ! t_blocking was formerly 2; 3 is a better choice
-    t_blocking = 3 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
+  ! t_blocking was formerly 2; 3 is a better choice
+  t_blocking = 3 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
+  ! we only use t_blocking if we can call it fully; this might be better but needs to be benchmarked
+  ! if ( na >= ((t_blocking+1)*nbw) ) then
   cwy_blocking = t_blocking * nbw
   allocate(tmp1(max_local_cols*cwy_blocking))
@@ -1189,23 +1187,33 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
   allocate(tmat_complete(cwy_blocking,cwy_blocking))
   allocate(t_tmp(cwy_blocking,nbw))
   allocate(t_tmp2(cwy_blocking,nbw))
-  else
-    allocate(tmp1(max_local_cols*nbw))
-    allocate(tmp2(max_local_cols*nbw))
-    allocate(hvb(max_local_rows*nbw))
-    allocate(hvm(max_local_rows,nbw))
-  endif
+  ! else
+  !   allocate(tmp1(max_local_cols*nbw))
+  !   allocate(tmp2(max_local_cols*nbw))
+  !   allocate(hvb(max_local_rows*nbw))
+  !   allocate(hvm(max_local_rows,nbw))
+  ! endif
   hvm = 0 ! Must be set to 0 !!!
   hvb = 0 ! Safety only
   l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q
-  ! This conditional has been introduced by the same merge error. Always execute this code path
-  if (useQR) then
+  ! if ( na >= ((t_blocking+1)*nbw) ) then
   do istep=1,((na-1)/nbw-1)/t_blocking + 1
-    n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
+    ! This is the call when using na >= ((t_blocking+1)*nbw)
+    ! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
+    ! As an alternative we add some special case handling if na < cwy_blocking
+    IF (na < cwy_blocking) THEN
+      n_cols = MAX(0, na-nbw)
+      IF ( n_cols .eq. 0 ) THEN
+        EXIT
+      END IF
+    ELSE
+      n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
+    END IF
     ! Broadcast all Householder vectors for current step compressed in hvb
@@ -1278,72 +1286,72 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
     endif
   enddo
-  else ! do not useQR
-
-    do istep=1,(na-1)/nbw
-
-      n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
-
-      ! Broadcast all Householder vectors for current step compressed in hvb
-
-      nb = 0
-      ns = 0
-
-      do lc = 1, n_cols
-        ncol = istep*nbw + lc ! absolute column number of householder vector
-        nrow = ncol - nbw ! absolute number of pivot row
-
-        l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
-        l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
-
-        if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
-
-        nb = nb+l_rows
-
-        if (lc==n_cols .or. mod(ncol,nblk)==0) then
-          call MPI_Bcast(hvb(ns+1),nb-ns,MPI_REAL8,pcol(ncol, nblk, np_cols),mpi_comm_cols,mpierr)
-          ns = nb
-        endif
-      enddo
-
-      ! Expand compressed Householder vectors into matrix hvm
-
-      nb = 0
-      do lc = 1, n_cols
-        nrow = (istep-1)*nbw+lc ! absolute number of pivot row
-        l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
-
-        hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
-        if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.
-
-        nb = nb+l_rows
-      enddo
-
-      l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
-
-      ! Q = Q - V * T**T * V**T * Q
-
-      if (l_rows>0) then
-        call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
-                   q,ldq,0.d0,tmp1,n_cols)
-      else
-        tmp1(1:l_cols*n_cols) = 0
-      endif
-
-      call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
-
-      if (l_rows>0) then
-        call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
-        call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
-                   tmp2,n_cols,1.d0,q,ldq)
-      endif
-    enddo
-  endif ! endQR
+  ! else
+  !
+  !   do istep=1,(na-1)/nbw
+  !
+  !     n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
+  !
+  !     ! Broadcast all Householder vectors for current step compressed in hvb
+  !
+  !     nb = 0
+  !     ns = 0
+  !
+  !     do lc = 1, n_cols
+  !       ncol = istep*nbw + lc ! absolute column number of householder vector
+  !       nrow = ncol - nbw ! absolute number of pivot row
+  !
+  !       l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
+  !       l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
+  !
+  !       if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
+  !
+  !       nb = nb+l_rows
+  !
+  !       if (lc==n_cols .or. mod(ncol,nblk)==0) then
+  !         call MPI_Bcast(hvb(ns+1),nb-ns,MPI_REAL8,pcol(ncol, nblk, np_cols),mpi_comm_cols,mpierr)
+  !         ns = nb
+  !       endif
+  !     enddo
+  !
+  !     ! Expand compressed Householder vectors into matrix hvm
+  !
+  !     nb = 0
+  !     do lc = 1, n_cols
+  !       nrow = (istep-1)*nbw+lc ! absolute number of pivot row
+  !       l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
+  !
+  !       hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
+  !       if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.
+  !
+  !       nb = nb+l_rows
+  !     enddo
+  !
+  !     l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
+  !
+  !     ! Q = Q - V * T**T * V**T * Q
+  !
+  !     if (l_rows>0) then
+  !       call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
+  !                  q,ldq,0.d0,tmp1,n_cols)
+  !     else
+  !       tmp1(1:l_cols*n_cols) = 0
+  !     endif
+  !
+  !     call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
+  !
+  !     if (l_rows>0) then
+  !       call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
+  !       call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
+  !                  tmp2,n_cols,1.d0,q,ldq)
+  !     endif
+  !   enddo
+  ! endif
   deallocate(tmp1, tmp2, hvb, hvm)
-  if (useQR) then
+  ! if ( na >= ((t_blocking+1)*nbw) ) then
   deallocate(tmat_complete, t_tmp, t_tmp2)
-  endif
+  ! endif
 #ifdef HAVE_DETAILED_TIMINGS
   call timer%stop("trans_ev_band_to_full_real")