Commit b48cf00a authored by Andreas Marek
Browse files

Start to remove assumed-size arrays

This commit is not ABI compatible, since it changes the interfaces
of some routines.

Also, introduce type checking for transpose and reduce_add routines
parent bf168297
This diff is collapsed.
...@@ -112,6 +112,7 @@ module ELPA2 ...@@ -112,6 +112,7 @@ module ELPA2
contains contains
function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, & function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, &
mpi_comm_rows, mpi_comm_cols, & mpi_comm_rows, mpi_comm_cols, &
mpi_comm_all, THIS_REAL_ELPA_KERNEL_API,& mpi_comm_all, THIS_REAL_ELPA_KERNEL_API,&
useQR) result(success) useQR) result(success)
...@@ -159,10 +160,10 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -159,10 +160,10 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
integer, intent(in), optional :: THIS_REAL_ELPA_KERNEL_API integer, intent(in), optional :: THIS_REAL_ELPA_KERNEL_API
integer :: THIS_REAL_ELPA_KERNEL integer :: THIS_REAL_ELPA_KERNEL
integer, intent(in) :: na, nev, lda, ldq, mpi_comm_rows, & integer, intent(in) :: na, nev, lda, ldq, matrixCols, mpi_comm_rows, &
mpi_comm_cols, mpi_comm_all mpi_comm_cols, mpi_comm_all
integer, intent(in) :: nblk integer, intent(in) :: nblk
real*8, intent(inout) :: a(lda,*), ev(na), q(ldq,*) real*8, intent(inout) :: a(lda,matrixCols), ev(na), q(ldq,matrixCols)
integer :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr integer :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr
integer :: nbw, num_blocks integer :: nbw, num_blocks
...@@ -290,7 +291,7 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -290,7 +291,7 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
! Solve tridiagonal system ! Solve tridiagonal system
ttt0 = MPI_Wtime() ttt0 = MPI_Wtime()
call solve_tridi(na, nev, ev, e, q, ldq, nblk, mpi_comm_rows, & call solve_tridi(na, nev, ev, e, q, ldq, nblk, matrixCols, mpi_comm_rows, &
mpi_comm_cols, wantDebug, success) mpi_comm_cols, wantDebug, success)
if (.not.(success)) return if (.not.(success)) return
...@@ -338,7 +339,7 @@ end function solve_evp_real_2stage ...@@ -338,7 +339,7 @@ end function solve_evp_real_2stage
!------------------------------------------------------------------------------- !-------------------------------------------------------------------------------
function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols, & matrixCols, mpi_comm_rows, mpi_comm_cols, &
mpi_comm_all, THIS_COMPLEX_ELPA_KERNEL_API) result(success) mpi_comm_all, THIS_COMPLEX_ELPA_KERNEL_API) result(success)
!------------------------------------------------------------------------------- !-------------------------------------------------------------------------------
...@@ -381,8 +382,8 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -381,8 +382,8 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
implicit none implicit none
integer, intent(in), optional :: THIS_COMPLEX_ELPA_KERNEL_API integer, intent(in), optional :: THIS_COMPLEX_ELPA_KERNEL_API
integer :: THIS_COMPLEX_ELPA_KERNEL integer :: THIS_COMPLEX_ELPA_KERNEL
integer, intent(in) :: na, nev, lda, ldq, nblk, mpi_comm_rows, mpi_comm_cols, mpi_comm_all integer, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
complex*16, intent(inout) :: a(lda,*), q(ldq,*) complex*16, intent(inout) :: a(lda,matrixCols), q(ldq,matrixCols)
real*8, intent(inout) :: ev(na) real*8, intent(inout) :: ev(na)
integer :: my_prow, my_pcol, np_rows, np_cols, mpierr, my_pe, n_pes integer :: my_prow, my_pcol, np_rows, np_cols, mpierr, my_pe, n_pes
...@@ -494,7 +495,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -494,7 +495,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
! Solve tridiagonal system ! Solve tridiagonal system
ttt0 = MPI_Wtime() ttt0 = MPI_Wtime()
call solve_tridi(na, nev, ev, e, q_real, ubound(q_real,1), nblk, & call solve_tridi(na, nev, ev, e, q_real, ubound(q_real,dim=1), nblk, matrixCols, &
mpi_comm_rows, mpi_comm_cols, wantDebug, success) mpi_comm_rows, mpi_comm_cols, wantDebug, success)
if (.not.(success)) return if (.not.(success)) return
...@@ -774,15 +775,15 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -774,15 +775,15 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
vav = 0 vav = 0
if (l_rows>0) & if (l_rows>0) &
call dsyrk('U','T',n_cols,l_rows,1.d0,vmr,ubound(vmr,1),0.d0,vav,ubound(vav,1)) call dsyrk('U','T',n_cols,l_rows,1.d0,vmr,ubound(vmr,dim=1),0.d0,vav,ubound(vav,dim=1))
call symm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_rows) call symm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_rows)
! Calculate triangular matrix T for block Householder Transformation ! Calculate triangular matrix T for block Householder Transformation
do lc=n_cols,1,-1 do lc=n_cols,1,-1
tau = tmat(lc,lc,istep) tau = tmat(lc,lc,istep)
if (lc<n_cols) then if (lc<n_cols) then
call dtrmv('U','T','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,1),vav(lc+1,lc),1) call dtrmv('U','T','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,dim=1),vav(lc+1,lc),1)
tmat(lc,lc+1:n_cols,istep) = -tau * vav(lc+1:n_cols,lc) tmat(lc,lc+1:n_cols,istep) = -tau * vav(lc+1:n_cols,lc)
endif endif
enddo enddo
...@@ -790,8 +791,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -790,8 +791,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! Transpose vmr -> vmc (stored in umc, second half) ! Transpose vmr -> vmc (stored in umc, second half)
call elpa_transpose_vectors (vmr, ubound(vmr,1), mpi_comm_rows, & call elpa_transpose_vectors_real (vmr, ubound(vmr,dim=1), mpi_comm_rows, &
umc(1,n_cols+1), ubound(umc,1), mpi_comm_cols, & umc(1,n_cols+1), ubound(umc,dim=1), mpi_comm_cols, &
1, istep*nbw, n_cols, nblk) 1, istep*nbw, n_cols, nblk)
! Calculate umc = A**T * vmr ! Calculate umc = A**T * vmr
...@@ -809,13 +810,13 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -809,13 +810,13 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
if (lce<lcs) cycle if (lce<lcs) cycle
lre = min(l_rows,(i+1)*l_rows_tile) lre = min(l_rows,(i+1)*l_rows_tile)
call DGEMM('T','N',lce-lcs+1,n_cols,lre,1.d0,a(1,lcs),ubound(a,1), & call DGEMM('T','N',lce-lcs+1,n_cols,lre,1.d0,a(1,lcs),ubound(a,dim=1), &
vmr,ubound(vmr,1),1.d0,umc(lcs,1),ubound(umc,1)) vmr,ubound(vmr,dim=1),1.d0,umc(lcs,1),ubound(umc,dim=1))
if (i==0) cycle if (i==0) cycle
lre = min(l_rows,i*l_rows_tile) lre = min(l_rows,i*l_rows_tile)
call DGEMM('N','N',lre,n_cols,lce-lcs+1,1.d0,a(1,lcs),lda, & call DGEMM('N','N',lre,n_cols,lce-lcs+1,1.d0,a(1,lcs),lda, &
umc(lcs,n_cols+1),ubound(umc,1),1.d0,vmr(1,n_cols+1),ubound(vmr,1)) umc(lcs,n_cols+1),ubound(umc,dim=1),1.d0,vmr(1,n_cols+1),ubound(vmr,dim=1))
enddo enddo
endif endif
...@@ -825,8 +826,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -825,8 +826,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! global tile size is smaller than the global remaining matrix ! global tile size is smaller than the global remaining matrix
if (tile_size < istep*nbw) then if (tile_size < istep*nbw) then
call elpa_reduce_add_vectors (vmr(1,n_cols+1),ubound(vmr,1),mpi_comm_rows, & call elpa_reduce_add_vectors_real (vmr(1,n_cols+1),ubound(vmr,dim=1),mpi_comm_rows, &
umc, ubound(umc,1), mpi_comm_cols, & umc, ubound(umc,dim=1), mpi_comm_cols, &
istep*nbw, n_cols, nblk) istep*nbw, n_cols, nblk)
endif endif
...@@ -839,22 +840,22 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -839,22 +840,22 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! U = U * Tmat**T ! U = U * Tmat**T
call dtrmm('Right','Upper','Trans','Nonunit',l_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,1),umc,ubound(umc,1)) call dtrmm('Right','Upper','Trans','Nonunit',l_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,dim=1),umc,ubound(umc,dim=1))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T ! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
call dgemm('T','N',n_cols,n_cols,l_cols,1.d0,umc,ubound(umc,1),umc(1,n_cols+1),ubound(umc,1),0.d0,vav,ubound(vav,1)) call dgemm('T','N',n_cols,n_cols,l_cols,1.d0,umc,ubound(umc,dim=1),umc(1,n_cols+1),ubound(umc,dim=1),0.d0,vav,ubound(vav,dim=1))
call dtrmm('Right','Upper','Trans','Nonunit',n_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,1),vav,ubound(vav,1)) call dtrmm('Right','Upper','Trans','Nonunit',n_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,dim=1),vav,ubound(vav,dim=1))
call symm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_cols) call symm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_cols)
! U = U - 0.5 * V * VAV ! U = U - 0.5 * V * VAV
call dgemm('N','N',l_cols,n_cols,n_cols,-0.5d0,umc(1,n_cols+1),ubound(umc,1),vav,ubound(vav,1),1.d0,umc,ubound(umc,1)) call dgemm('N','N',l_cols,n_cols,n_cols,-0.5d0,umc(1,n_cols+1),ubound(umc,dim=1),vav,ubound(vav,dim=1),1.d0,umc,ubound(umc,dim=1))
! Transpose umc -> umr (stored in vmr, second half) ! Transpose umc -> umr (stored in vmr, second half)
call elpa_transpose_vectors (umc, ubound(umc,1), mpi_comm_cols, & call elpa_transpose_vectors_real (umc, ubound(umc,dim=1), mpi_comm_cols, &
vmr(1,n_cols+1), ubound(vmr,1), mpi_comm_rows, & vmr(1,n_cols+1), ubound(vmr,dim=1), mpi_comm_rows, &
1, istep*nbw, n_cols, nblk) 1, istep*nbw, n_cols, nblk)
! A = A - V*U**T - U*V**T ! A = A - V*U**T - U*V**T
...@@ -865,7 +866,7 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, & ...@@ -865,7 +866,7 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
lre = min(l_rows,(i+1)*l_rows_tile) lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<1) cycle if (lce<lcs .or. lre<1) cycle
call dgemm('N','T',lre,lce-lcs+1,2*n_cols,-1.d0, & call dgemm('N','T',lre,lce-lcs+1,2*n_cols,-1.d0, &
vmr,ubound(vmr,1),umc(lcs,1),ubound(umc,1), & vmr,ubound(vmr,dim=1),umc(lcs,1),ubound(umc,dim=1), &
1.d0,a(1,lcs),lda) 1.d0,a(1,lcs),lda)
enddo enddo
...@@ -1089,7 +1090,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq, ...@@ -1089,7 +1090,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
! Q = Q - V * T**T * V**T * Q ! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then if (l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), & call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
q,ldq,0.d0,tmp1,n_cols) q,ldq,0.d0,tmp1,n_cols)
else else
tmp1(1:l_cols*n_cols) = 0 tmp1(1:l_cols*n_cols) = 0
...@@ -1099,7 +1100,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq, ...@@ -1099,7 +1100,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
if (l_rows>0) then if (l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat_complete,cwy_blocking,tmp2,n_cols) call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat_complete,cwy_blocking,tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), tmp2,n_cols,1.d0,q,ldq) call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), tmp2,n_cols,1.d0,q,ldq)
endif endif
enddo enddo
...@@ -1149,7 +1150,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq, ...@@ -1149,7 +1150,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
! Q = Q - V * T**T * V**T * Q ! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then if (l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), & call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
q,ldq,0.d0,tmp1,n_cols) q,ldq,0.d0,tmp1,n_cols)
else else
tmp1(1:l_cols*n_cols) = 0 tmp1(1:l_cols*n_cols) = 0
...@@ -1158,8 +1159,8 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq, ...@@ -1158,8 +1159,8 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr) call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
if (l_rows>0) then if (l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,1),tmp2,n_cols) call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), & call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
tmp2,n_cols,1.d0,q,ldq) tmp2,n_cols,1.d0,q,ldq)
endif endif
enddo enddo
...@@ -3409,24 +3410,24 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, ...@@ -3409,24 +3410,24 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
vav = 0 vav = 0
if (l_rows>0) & if (l_rows>0) &
call zherk('U','C',n_cols,l_rows,CONE,vmr,ubound(vmr,1),CZERO,vav,ubound(vav,1)) call zherk('U','C',n_cols,l_rows,CONE,vmr,ubound(vmr,dim=1),CZERO,vav,ubound(vav,dim=1))
call herm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_rows) call herm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_rows)
! Calculate triangular matrix T for block Householder Transformation ! Calculate triangular matrix T for block Householder Transformation
do lc=n_cols,1,-1 do lc=n_cols,1,-1
tau = tmat(lc,lc,istep) tau = tmat(lc,lc,istep)
if (lc<n_cols) then if (lc<n_cols) then
call ztrmv('U','C','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,1),vav(lc+1,lc),1) call ztrmv('U','C','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,dim=1),vav(lc+1,lc),1)
tmat(lc,lc+1:n_cols,istep) = -tau * conjg(vav(lc+1:n_cols,lc)) tmat(lc,lc+1:n_cols,istep) = -tau * conjg(vav(lc+1:n_cols,lc))
endif endif
enddo enddo
! Transpose vmr -> vmc (stored in umc, second half) ! Transpose vmr -> vmc (stored in umc, second half)
call elpa_transpose_vectors (vmr, 2*ubound(vmr,1), mpi_comm_rows, & call elpa_transpose_vectors_complex (vmr, ubound(vmr,dim=1), mpi_comm_rows, &
umc(1,n_cols+1), 2*ubound(umc,1), mpi_comm_cols, & umc(1,n_cols+1), ubound(umc,dim=1), mpi_comm_cols, &
1, 2*istep*nbw, n_cols, 2*nblk) 1, istep*nbw, n_cols, nblk)
! Calculate umc = A**T * vmr ! Calculate umc = A**T * vmr
! Note that the distributed A has to be transposed ! Note that the distributed A has to be transposed
...@@ -3443,13 +3444,13 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, ...@@ -3443,13 +3444,13 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
if (lce<lcs) cycle if (lce<lcs) cycle
lre = min(l_rows,(i+1)*l_rows_tile) lre = min(l_rows,(i+1)*l_rows_tile)
call ZGEMM('C','N',lce-lcs+1,n_cols,lre,CONE,a(1,lcs),ubound(a,1), & call ZGEMM('C','N',lce-lcs+1,n_cols,lre,CONE,a(1,lcs),ubound(a,dim=1), &
vmr,ubound(vmr,1),CONE,umc(lcs,1),ubound(umc,1)) vmr,ubound(vmr,dim=1),CONE,umc(lcs,1),ubound(umc,dim=1))
if (i==0) cycle if (i==0) cycle
lre = min(l_rows,i*l_rows_tile) lre = min(l_rows,i*l_rows_tile)
call ZGEMM('N','N',lre,n_cols,lce-lcs+1,CONE,a(1,lcs),lda, & call ZGEMM('N','N',lre,n_cols,lce-lcs+1,CONE,a(1,lcs),lda, &
umc(lcs,n_cols+1),ubound(umc,1),CONE,vmr(1,n_cols+1),ubound(vmr,1)) umc(lcs,n_cols+1),ubound(umc,dim=1),CONE,vmr(1,n_cols+1),ubound(vmr,dim=1))
enddo enddo
endif endif
...@@ -3459,9 +3460,9 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, ...@@ -3459,9 +3460,9 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
! global tile size is smaller than the global remaining matrix ! global tile size is smaller than the global remaining matrix
if (tile_size < istep*nbw) then if (tile_size < istep*nbw) then
call elpa_reduce_add_vectors (vmr(1,n_cols+1),2*ubound(vmr,1),mpi_comm_rows, & call elpa_reduce_add_vectors_complex (vmr(1,n_cols+1),ubound(vmr,dim=1),mpi_comm_rows, &
umc, 2*ubound(umc,1), mpi_comm_cols, & umc, ubound(umc,dim=1), mpi_comm_cols, &
2*istep*nbw, n_cols, 2*nblk) istep*nbw, n_cols, nblk)
endif endif
if (l_cols>0) then if (l_cols>0) then
...@@ -3473,23 +3474,25 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, ...@@ -3473,23 +3474,25 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
! U = U * Tmat**T ! U = U * Tmat**T
call ztrmm('Right','Upper','C','Nonunit',l_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,1),umc,ubound(umc,1)) call ztrmm('Right','Upper','C','Nonunit',l_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),umc,ubound(umc,dim=1))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T ! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
call zgemm('C','N',n_cols,n_cols,l_cols,CONE,umc,ubound(umc,1),umc(1,n_cols+1),ubound(umc,1),CZERO,vav,ubound(vav,1)) call zgemm('C','N',n_cols,n_cols,l_cols,CONE,umc,ubound(umc,dim=1),umc(1,n_cols+1), &
call ztrmm('Right','Upper','C','Nonunit',n_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,1),vav,ubound(vav,1)) ubound(umc,dim=1),CZERO,vav,ubound(vav,dim=1))
call ztrmm('Right','Upper','C','Nonunit',n_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),vav,ubound(vav,dim=1))
call herm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_cols) call herm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_cols)
! U = U - 0.5 * V * VAV ! U = U - 0.5 * V * VAV
call zgemm('N','N',l_cols,n_cols,n_cols,(-0.5d0,0.d0),umc(1,n_cols+1),ubound(umc,1),vav,ubound(vav,1),CONE,umc,ubound(umc,1)) call zgemm('N','N',l_cols,n_cols,n_cols,(-0.5d0,0.d0),umc(1,n_cols+1),ubound(umc,dim=1),vav,ubound(vav,dim=1), &
CONE,umc,ubound(umc,dim=1))
! Transpose umc -> umr (stored in vmr, second half) ! Transpose umc -> umr (stored in vmr, second half)
call elpa_transpose_vectors (umc, 2*ubound(umc,1), mpi_comm_cols, & call elpa_transpose_vectors_complex (umc, ubound(umc,dim=1), mpi_comm_cols, &
vmr(1,n_cols+1), 2*ubound(vmr,1), mpi_comm_rows, & vmr(1,n_cols+1), ubound(vmr,dim=1), mpi_comm_rows, &
1, 2*istep*nbw, n_cols, 2*nblk) 1, istep*nbw, n_cols, nblk)
! A = A - V*U**T - U*V**T ! A = A - V*U**T - U*V**T
...@@ -3499,7 +3502,7 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, ...@@ -3499,7 +3502,7 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
lre = min(l_rows,(i+1)*l_rows_tile) lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<1) cycle if (lce<lcs .or. lre<1) cycle
call zgemm('N','C',lre,lce-lcs+1,2*n_cols,-CONE, & call zgemm('N','C',lre,lce-lcs+1,2*n_cols,-CONE, &
vmr,ubound(vmr,1),umc(lcs,1),ubound(umc,1), & vmr,ubound(vmr,dim=1),umc(lcs,1),ubound(umc,dim=1), &
CONE,a(1,lcs),lda) CONE,a(1,lcs),lda)
enddo enddo
...@@ -3679,15 +3682,15 @@ subroutine trans_ev_band_to_full_complex(na, nqc, nblk, nbw, a, lda, tmat, q, ld ...@@ -3679,15 +3682,15 @@ subroutine trans_ev_band_to_full_complex(na, nqc, nblk, nbw, a, lda, tmat, q, ld
! Q = Q - V * T**T * V**T * Q ! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then if (l_rows>0) then
call zgemm('C','N',n_cols,l_cols,l_rows,CONE,hvm,ubound(hvm,1), & call zgemm('C','N',n_cols,l_cols,l_rows,CONE,hvm,ubound(hvm,dim=1), &
q,ldq,CZERO,tmp1,n_cols) q,ldq,CZERO,tmp1,n_cols)
else else
tmp1(1:l_cols*n_cols) = 0 tmp1(1:l_cols*n_cols) = 0
endif endif
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_DOUBLE_COMPLEX,MPI_SUM,mpi_comm_rows,mpierr) call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_DOUBLE_COMPLEX,MPI_SUM,mpi_comm_rows,mpierr)
if (l_rows>0) then if (l_rows>0) then
call ztrmm('L','U','C','N',n_cols,l_cols,CONE,tmat(1,1,istep),ubound(tmat,1),tmp2,n_cols) call ztrmm('L','U','C','N',n_cols,l_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
call zgemm('N','N',l_rows,l_cols,n_cols,-CONE,hvm,ubound(hvm,1), & call zgemm('N','N',l_rows,l_cols,n_cols,-CONE,hvm,ubound(hvm,dim=1), &
tmp2,n_cols,CONE,q,ldq) tmp2,n_cols,CONE,q,ldq)
endif endif
......
...@@ -64,21 +64,21 @@ ...@@ -64,21 +64,21 @@
end function end function
!c> int elpa_solve_evp_real_stage1(int na, int nev, int ncols, double *a, int lda, double *ev, double *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols); !c> int elpa_solve_evp_real_stage1(int na, int nev, double *a, int lda, double *ev, double *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols);
function solve_elpa1_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, & function solve_elpa1_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols) & matrixCols, mpi_comm_rows, mpi_comm_cols) &
result(success) bind(C,name="elpa_solve_evp_real_1stage") result(success) bind(C,name="elpa_solve_evp_real_1stage")
use, intrinsic :: iso_c_binding use, intrinsic :: iso_c_binding
use elpa1, only : solve_evp_real use elpa1, only : solve_evp_real
integer(kind=c_int) :: success integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows
real(kind=c_double) :: a(1:lda,1:ncols), ev(1:na), q(1:ldq,1:ncols) real(kind=c_double) :: a(1:lda,1:matrixCols), ev(1:na), q(1:ldq,1:matrixCols)
logical :: successFortran logical :: successFortran
successFortran = solve_evp_real(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols) successFortran = solve_evp_real(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
if (successFortran) then if (successFortran) then
success = 1 success = 1
...@@ -88,22 +88,22 @@ ...@@ -88,22 +88,22 @@
end function end function
! int elpa_solve_evp_complex_stage1(int na, int nev, int ncols double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols); ! int elpa_solve_evp_complex_stage1(int na, int nev, double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols);
function solve_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, & function solve_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols) & matrixCols, mpi_comm_rows, mpi_comm_cols) &
result(success) bind(C,name="elpa_solve_evp_complex_1stage") result(success) bind(C,name="elpa_solve_evp_complex_1stage")
use, intrinsic :: iso_c_binding use, intrinsic :: iso_c_binding
use elpa1, only : solve_evp_complex use elpa1, only : solve_evp_complex
integer(kind=c_int) :: success integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows
complex(kind=c_double_complex) :: a(1:lda,1:ncols), q(1:ldq,1:ncols) complex(kind=c_double_complex) :: a(1:lda,1:matrixCols), q(1:ldq,1:matrixCols)
real(kind=c_double) :: ev(1:na) real(kind=c_double) :: ev(1:na)
logical :: successFortran logical :: successFortran
successFortran = solve_evp_complex(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols) successFortran = solve_evp_complex(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
if (successFortran) then if (successFortran) then
success = 1 success = 1
...@@ -113,9 +113,9 @@ ...@@ -113,9 +113,9 @@
end function end function
!c> int elpa_solve_evp_real_stage2(int na, int nev, int ncols, double *a, int lda, double *ev, double *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols, int THIS_REAL_ELPA_KERNEL_API, int useQR); !c> int elpa_solve_evp_real_stage2(int na, int nev, double *a, int lda, double *ev, double *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols, int THIS_REAL_ELPA_KERNEL_API, int useQR);
function solve_elpa2_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, & function solve_elpa2_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all, & matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
THIS_REAL_ELPA_KERNEL_API, useQR) & THIS_REAL_ELPA_KERNEL_API, useQR) &
result(success) bind(C,name="elpa_solve_evp_real_2stage") result(success) bind(C,name="elpa_solve_evp_real_2stage")
...@@ -123,10 +123,10 @@ ...@@ -123,10 +123,10 @@
use elpa2, only : solve_evp_real_2stage use elpa2, only : solve_evp_real_2stage
integer(kind=c_int) :: success integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows, & integer(kind=c_int), value, intent(in) :: na, nev, lda,