Commit b6ce5fd4 authored by Alexander Heinecke

first attempt to fix the bug which was rolled back in 402629d9.

Current fix does as much blocking as possible, which should be
beneficial from both a compute and communication point of view.

Additionally, a second possible fix was added which just calls
the blocked version if the local matrix has a sufficient size.
This might create more, smaller messages at scale.
parent 402629d9
......@@ -1173,13 +1173,11 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
max_local_rows = max_blocks_row*nblk
max_local_cols = max_blocks_col*nblk
! t_blocking was formerly 2; 3 is a better choice
t_blocking = 3 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
! This conditional was introduced due to a merge error. For better performance, this code path should
! always be used.
if (useQR) then
! t_blocking was formerly 2; 3 is a better choice
t_blocking = 3 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
! We only use t_blocking if it can be applied fully; this might be better but needs to be benchmarked.
! if ( na >= ((t_blocking+1)*nbw) ) then
cwy_blocking = t_blocking * nbw
allocate(tmp1(max_local_cols*cwy_blocking))
......@@ -1189,23 +1187,33 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
allocate(tmat_complete(cwy_blocking,cwy_blocking))
allocate(t_tmp(cwy_blocking,nbw))
allocate(t_tmp2(cwy_blocking,nbw))
else
allocate(tmp1(max_local_cols*nbw))
allocate(tmp2(max_local_cols*nbw))
allocate(hvb(max_local_rows*nbw))
allocate(hvm(max_local_rows,nbw))
endif
! else
! allocate(tmp1(max_local_cols*nbw))
! allocate(tmp2(max_local_cols*nbw))
! allocate(hvb(max_local_rows*nbw))
! allocate(hvm(max_local_rows,nbw))
! endif
hvm = 0 ! Must be set to 0 !!!
hvb = 0 ! Safety only
l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q
! This conditional was introduced by the same merge error. Always execute this code path.
if (useQR) then
! if ( na >= ((t_blocking+1)*nbw) ) then
do istep=1,((na-1)/nbw-1)/t_blocking + 1
n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
! This is the call when using na >= ((t_blocking+1)*nbw)
! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
! As an alternative we add some special case handling if na < cwy_blocking
IF (na < cwy_blocking) THEN
n_cols = MAX(0, na-nbw)
IF ( n_cols .eq. 0 ) THEN
EXIT
END IF
ELSE
n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
END IF
! Broadcast all Householder vectors for current step compressed in hvb
......@@ -1278,72 +1286,72 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
endif
enddo
else ! do not useQR
do istep=1,(na-1)/nbw
n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
! Broadcast all Householder vectors for current step compressed in hvb
nb = 0
ns = 0
do lc = 1, n_cols
ncol = istep*nbw + lc ! absolute column number of householder vector
nrow = ncol - nbw ! absolute number of pivot row
l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
nb = nb+l_rows
if (lc==n_cols .or. mod(ncol,nblk)==0) then
call MPI_Bcast(hvb(ns+1),nb-ns,MPI_REAL8,pcol(ncol, nblk, np_cols),mpi_comm_cols,mpierr)
ns = nb
endif
enddo
! Expand compressed Householder vectors into matrix hvm
nb = 0
do lc = 1, n_cols
nrow = (istep-1)*nbw+lc ! absolute number of pivot row
l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.
nb = nb+l_rows
enddo
l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
q,ldq,0.d0,tmp1,n_cols)
else
tmp1(1:l_cols*n_cols) = 0
endif
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
if (l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
tmp2,n_cols,1.d0,q,ldq)
endif
enddo
endif ! endQR
! else
!
! do istep=1,(na-1)/nbw
!
! n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
!
! ! Broadcast all Householder vectors for current step compressed in hvb
!
! nb = 0
! ns = 0
!
! do lc = 1, n_cols
! ncol = istep*nbw + lc ! absolute column number of householder vector
! nrow = ncol - nbw ! absolute number of pivot row
!
! l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
! l_colh = local_index(ncol , my_pcol, np_cols, nblk, -1) ! HV local column number
!
! if (my_pcol==pcol(ncol, nblk, np_cols)) hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
!
! nb = nb+l_rows
!
! if (lc==n_cols .or. mod(ncol,nblk)==0) then
! call MPI_Bcast(hvb(ns+1),nb-ns,MPI_REAL8,pcol(ncol, nblk, np_cols),mpi_comm_cols,mpierr)
! ns = nb
! endif
! enddo
!
! ! Expand compressed Householder vectors into matrix hvm
!
! nb = 0
! do lc = 1, n_cols
! nrow = (istep-1)*nbw+lc ! absolute number of pivot row
! l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
!
! hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
! if (my_prow==prow(nrow, nblk, np_rows)) hvm(l_rows+1,lc) = 1.
!
! nb = nb+l_rows
! enddo
!
! l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
!
! ! Q = Q - V * T**T * V**T * Q
!
! if (l_rows>0) then
! call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
! q,ldq,0.d0,tmp1,n_cols)
! else
! tmp1(1:l_cols*n_cols) = 0
! endif
!
! call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
!
! if (l_rows>0) then
! call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
! call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
! tmp2,n_cols,1.d0,q,ldq)
! endif
! enddo
! endif
deallocate(tmp1, tmp2, hvb, hvm)
if (useQR) then
! if ( na >= ((t_blocking+1)*nbw) ) then
deallocate(tmat_complete, t_tmp, t_tmp2)
endif
! endif
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("trans_ev_band_to_full_real")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment