Commit bbb5785a authored by Thomas Auckenthaler
Browse files

Higher blocking for banded2full back transformation added

parent 086d6ecb
......@@ -724,9 +724,9 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
integer l_cols, l_rows, l_colh, n_cols
integer istep, lc, ncol, nrow, nb, ns
real*8, allocatable:: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
real*8, allocatable:: tmp1(:), tmp2(:), hvb(:), hvm(:,:), tmat_complete(:,:), t_tmp(:,:)
integer pcol, prow, i
integer pcol, prow, i, cwy_blocking, t_blocking, t_cols, t_rows
pcol(i) = MOD((i-1)/nblk,np_cols) !Processor col for global col number
prow(i) = MOD((i-1)/nblk,np_rows) !Processor row for global row number
......@@ -742,20 +742,25 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
max_local_rows = max_blocks_row*nblk
max_local_cols = max_blocks_col*nblk
t_blocking = 2 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
cwy_blocking = t_blocking * nbw
allocate(tmp1(max_local_cols*nbw))
allocate(tmp2(max_local_cols*nbw))
allocate(hvb(max_local_rows*nbw))
allocate(hvm(max_local_rows,nbw))
allocate(tmp1(max_local_cols*cwy_blocking))
allocate(tmp2(max_local_cols*cwy_blocking))
allocate(hvb(max_local_rows*cwy_blocking))
allocate(hvm(max_local_rows,cwy_blocking))
allocate(tmat_complete(cwy_blocking,cwy_blocking))
allocate(t_tmp(cwy_blocking,nbw))
hvm = 0 ! Must be set to 0 !!!
hvb = 0 ! Safety only
l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q
do istep=1,(na-1)/nbw
do istep=1,((na-1)/nbw-1)/t_blocking + 1
n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
! Broadcast all Householder vectors for current step compressed in hvb
......@@ -763,7 +768,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
ns = 0
do lc = 1, n_cols
ncol = istep*nbw + lc ! absolute column number of householder vector
ncol = (istep-1)*cwy_blocking + nbw + lc ! absolute column number of householder vector
nrow = ncol - nbw ! absolute number of pivot row
l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
......@@ -783,7 +788,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
nb = 0
do lc = 1, n_cols
nrow = (istep-1)*nbw+lc ! absolute number of pivot row
nrow = (istep-1)*cwy_blocking + lc ! absolute number of pivot row
l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
hvm(1:l_rows,lc) = hvb(nb+1:nb+l_rows)
......@@ -792,21 +797,35 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
nb = nb+l_rows
enddo
l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
l_rows = local_index(MIN(na,(istep+1)*cwy_blocking), my_prow, np_rows, nblk, -1)
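! Assemble tmat_complete from up to t_blocking consecutive tmat(:,:,k) blocks: each block goes on the diagonal, and the coupling block above it is computed from the Householder vectors in hvm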
tmat_complete = 0
do i = 1, t_blocking
t_cols = MIN(nbw, n_cols - (i-1)*nbw)
if(t_cols <= 0) exit
t_rows = (i - 1) * nbw
tmat_complete(t_rows+1:t_rows+t_cols,t_rows+1:t_rows+t_cols) = tmat(1:t_cols,1:t_cols,(istep-1)*t_blocking + i)
if(i > 1) then
call dgemm('T', 'N', t_rows, t_cols, l_rows, 1.d0, hvm(1,1), max_local_rows, hvm(1,(i-1)*nbw+1), max_local_rows, 0.d0, t_tmp, cwy_blocking)
call mpi_allreduce(MPI_IN_PLACE,t_tmp,cwy_blocking*nbw,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
call dtrmm('L','U','N','N',t_rows,t_cols,1.0d0,tmat_complete,cwy_blocking,t_tmp,cwy_blocking)
call dtrmm('R','U','N','N',t_rows,t_cols,-1.0d0,tmat_complete(t_rows+1,t_rows+1),cwy_blocking,t_tmp,cwy_blocking)
tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp(1:t_rows,1:t_cols)
endif
enddo
! Q = Q - V * T**T * V**T * Q
if(l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), &
q,ldq,0.d0,tmp1,n_cols)
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), q,ldq,0.d0,tmp1,n_cols)
else
tmp1(1:l_cols*n_cols) = 0
endif
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
if(l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,1),tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), &
tmp2,n_cols,1.d0,q,ldq)
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat_complete,cwy_blocking,tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), tmp2,n_cols,1.d0,q,ldq)
endif
enddo
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment