Commit b48cf00a authored by Andreas Marek's avatar Andreas Marek
Browse files

Start to remove assumed size arrays

This commit is not ABI compatible, since it changes the interfaces
of some routines

Also, introduce type checking for transpose and reduce_add routines
parent bf168297
This diff is collapsed.
......@@ -112,6 +112,7 @@ module ELPA2
contains
function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, &
mpi_comm_rows, mpi_comm_cols, &
mpi_comm_all, THIS_REAL_ELPA_KERNEL_API,&
useQR) result(success)
......@@ -159,10 +160,10 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
integer, intent(in), optional :: THIS_REAL_ELPA_KERNEL_API
integer :: THIS_REAL_ELPA_KERNEL
integer, intent(in) :: na, nev, lda, ldq, mpi_comm_rows, &
integer, intent(in) :: na, nev, lda, ldq, matrixCols, mpi_comm_rows, &
mpi_comm_cols, mpi_comm_all
integer, intent(in) :: nblk
real*8, intent(inout) :: a(lda,*), ev(na), q(ldq,*)
real*8, intent(inout) :: a(lda,matrixCols), ev(na), q(ldq,matrixCols)
integer :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr
integer :: nbw, num_blocks
......@@ -290,7 +291,7 @@ function solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
! Solve tridiagonal system
ttt0 = MPI_Wtime()
call solve_tridi(na, nev, ev, e, q, ldq, nblk, mpi_comm_rows, &
call solve_tridi(na, nev, ev, e, q, ldq, nblk, matrixCols, mpi_comm_rows, &
mpi_comm_cols, wantDebug, success)
if (.not.(success)) return
......@@ -338,7 +339,7 @@ end function solve_evp_real_2stage
!-------------------------------------------------------------------------------
function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols, &
matrixCols, mpi_comm_rows, mpi_comm_cols, &
mpi_comm_all, THIS_COMPLEX_ELPA_KERNEL_API) result(success)
!-------------------------------------------------------------------------------
......@@ -381,8 +382,8 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
implicit none
integer, intent(in), optional :: THIS_COMPLEX_ELPA_KERNEL_API
integer :: THIS_COMPLEX_ELPA_KERNEL
integer, intent(in) :: na, nev, lda, ldq, nblk, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
complex*16, intent(inout) :: a(lda,*), q(ldq,*)
integer, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
complex*16, intent(inout) :: a(lda,matrixCols), q(ldq,matrixCols)
real*8, intent(inout) :: ev(na)
integer :: my_prow, my_pcol, np_rows, np_cols, mpierr, my_pe, n_pes
......@@ -494,7 +495,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
! Solve tridiagonal system
ttt0 = MPI_Wtime()
call solve_tridi(na, nev, ev, e, q_real, ubound(q_real,1), nblk, &
call solve_tridi(na, nev, ev, e, q_real, ubound(q_real,dim=1), nblk, matrixCols, &
mpi_comm_rows, mpi_comm_cols, wantDebug, success)
if (.not.(success)) return
......@@ -774,15 +775,15 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
vav = 0
if (l_rows>0) &
call dsyrk('U','T',n_cols,l_rows,1.d0,vmr,ubound(vmr,1),0.d0,vav,ubound(vav,1))
call symm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_rows)
call dsyrk('U','T',n_cols,l_rows,1.d0,vmr,ubound(vmr,dim=1),0.d0,vav,ubound(vav,dim=1))
call symm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_rows)
! Calculate triangular matrix T for block Householder Transformation
do lc=n_cols,1,-1
tau = tmat(lc,lc,istep)
if (lc<n_cols) then
call dtrmv('U','T','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,1),vav(lc+1,lc),1)
call dtrmv('U','T','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,dim=1),vav(lc+1,lc),1)
tmat(lc,lc+1:n_cols,istep) = -tau * vav(lc+1:n_cols,lc)
endif
enddo
......@@ -790,8 +791,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! Transpose vmr -> vmc (stored in umc, second half)
call elpa_transpose_vectors (vmr, ubound(vmr,1), mpi_comm_rows, &
umc(1,n_cols+1), ubound(umc,1), mpi_comm_cols, &
call elpa_transpose_vectors_real (vmr, ubound(vmr,dim=1), mpi_comm_rows, &
umc(1,n_cols+1), ubound(umc,dim=1), mpi_comm_cols, &
1, istep*nbw, n_cols, nblk)
! Calculate umc = A**T * vmr
......@@ -809,13 +810,13 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
if (lce<lcs) cycle
lre = min(l_rows,(i+1)*l_rows_tile)
call DGEMM('T','N',lce-lcs+1,n_cols,lre,1.d0,a(1,lcs),ubound(a,1), &
vmr,ubound(vmr,1),1.d0,umc(lcs,1),ubound(umc,1))
call DGEMM('T','N',lce-lcs+1,n_cols,lre,1.d0,a(1,lcs),ubound(a,dim=1), &
vmr,ubound(vmr,dim=1),1.d0,umc(lcs,1),ubound(umc,dim=1))
if (i==0) cycle
lre = min(l_rows,i*l_rows_tile)
call DGEMM('N','N',lre,n_cols,lce-lcs+1,1.d0,a(1,lcs),lda, &
umc(lcs,n_cols+1),ubound(umc,1),1.d0,vmr(1,n_cols+1),ubound(vmr,1))
umc(lcs,n_cols+1),ubound(umc,dim=1),1.d0,vmr(1,n_cols+1),ubound(vmr,dim=1))
enddo
endif
......@@ -825,8 +826,8 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! global tile size is smaller than the global remaining matrix
if (tile_size < istep*nbw) then
call elpa_reduce_add_vectors (vmr(1,n_cols+1),ubound(vmr,1),mpi_comm_rows, &
umc, ubound(umc,1), mpi_comm_cols, &
call elpa_reduce_add_vectors_real (vmr(1,n_cols+1),ubound(vmr,dim=1),mpi_comm_rows, &
umc, ubound(umc,dim=1), mpi_comm_cols, &
istep*nbw, n_cols, nblk)
endif
......@@ -839,22 +840,22 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
! U = U * Tmat**T
call dtrmm('Right','Upper','Trans','Nonunit',l_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,1),umc,ubound(umc,1))
call dtrmm('Right','Upper','Trans','Nonunit',l_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,dim=1),umc,ubound(umc,dim=1))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
call dgemm('T','N',n_cols,n_cols,l_cols,1.d0,umc,ubound(umc,1),umc(1,n_cols+1),ubound(umc,1),0.d0,vav,ubound(vav,1))
call dtrmm('Right','Upper','Trans','Nonunit',n_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,1),vav,ubound(vav,1))
call dgemm('T','N',n_cols,n_cols,l_cols,1.d0,umc,ubound(umc,dim=1),umc(1,n_cols+1),ubound(umc,dim=1),0.d0,vav,ubound(vav,dim=1))
call dtrmm('Right','Upper','Trans','Nonunit',n_cols,n_cols,1.d0,tmat(1,1,istep),ubound(tmat,dim=1),vav,ubound(vav,dim=1))
call symm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_cols)
call symm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_cols)
! U = U - 0.5 * V * VAV
call dgemm('N','N',l_cols,n_cols,n_cols,-0.5d0,umc(1,n_cols+1),ubound(umc,1),vav,ubound(vav,1),1.d0,umc,ubound(umc,1))
call dgemm('N','N',l_cols,n_cols,n_cols,-0.5d0,umc(1,n_cols+1),ubound(umc,dim=1),vav,ubound(vav,dim=1),1.d0,umc,ubound(umc,dim=1))
! Transpose umc -> umr (stored in vmr, second half)
call elpa_transpose_vectors (umc, ubound(umc,1), mpi_comm_cols, &
vmr(1,n_cols+1), ubound(vmr,1), mpi_comm_rows, &
call elpa_transpose_vectors_real (umc, ubound(umc,dim=1), mpi_comm_cols, &
vmr(1,n_cols+1), ubound(vmr,dim=1), mpi_comm_rows, &
1, istep*nbw, n_cols, nblk)
! A = A - V*U**T - U*V**T
......@@ -865,7 +866,7 @@ subroutine bandred_real(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols, &
lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<1) cycle
call dgemm('N','T',lre,lce-lcs+1,2*n_cols,-1.d0, &
vmr,ubound(vmr,1),umc(lcs,1),ubound(umc,1), &
vmr,ubound(vmr,dim=1),umc(lcs,1),ubound(umc,dim=1), &
1.d0,a(1,lcs),lda)
enddo
......@@ -1089,7 +1090,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), &
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
q,ldq,0.d0,tmp1,n_cols)
else
tmp1(1:l_cols*n_cols) = 0
......@@ -1099,7 +1100,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
if (l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat_complete,cwy_blocking,tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), tmp2,n_cols,1.d0,q,ldq)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), tmp2,n_cols,1.d0,q,ldq)
endif
enddo
......@@ -1149,7 +1150,7 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,1), &
call dgemm('T','N',n_cols,l_cols,l_rows,1.d0,hvm,ubound(hvm,dim=1), &
q,ldq,0.d0,tmp1,n_cols)
else
tmp1(1:l_cols*n_cols) = 0
......@@ -1158,8 +1159,8 @@ subroutine trans_ev_band_to_full_real(na, nqc, nblk, nbw, a, lda, tmat, q, ldq,
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_REAL8,MPI_SUM,mpi_comm_rows,mpierr)
if (l_rows>0) then
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,1),tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,1), &
call dtrmm('L','U','T','N',n_cols,l_cols,1.0d0,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
call dgemm('N','N',l_rows,l_cols,n_cols,-1.d0,hvm,ubound(hvm,dim=1), &
tmp2,n_cols,1.d0,q,ldq)
endif
enddo
......@@ -3409,24 +3410,24 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
vav = 0
if (l_rows>0) &
call zherk('U','C',n_cols,l_rows,CONE,vmr,ubound(vmr,1),CZERO,vav,ubound(vav,1))
call herm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_rows)
call zherk('U','C',n_cols,l_rows,CONE,vmr,ubound(vmr,dim=1),CZERO,vav,ubound(vav,dim=1))
call herm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_rows)
! Calculate triangular matrix T for block Householder Transformation
do lc=n_cols,1,-1
tau = tmat(lc,lc,istep)
if (lc<n_cols) then
call ztrmv('U','C','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,1),vav(lc+1,lc),1)
call ztrmv('U','C','N',n_cols-lc,tmat(lc+1,lc+1,istep),ubound(tmat,dim=1),vav(lc+1,lc),1)
tmat(lc,lc+1:n_cols,istep) = -tau * conjg(vav(lc+1:n_cols,lc))
endif
enddo
! Transpose vmr -> vmc (stored in umc, second half)
call elpa_transpose_vectors (vmr, 2*ubound(vmr,1), mpi_comm_rows, &
umc(1,n_cols+1), 2*ubound(umc,1), mpi_comm_cols, &
1, 2*istep*nbw, n_cols, 2*nblk)
call elpa_transpose_vectors_complex (vmr, ubound(vmr,dim=1), mpi_comm_rows, &
umc(1,n_cols+1), ubound(umc,dim=1), mpi_comm_cols, &
1, istep*nbw, n_cols, nblk)
! Calculate umc = A**T * vmr
! Note that the distributed A has to be transposed
......@@ -3443,13 +3444,13 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
if (lce<lcs) cycle
lre = min(l_rows,(i+1)*l_rows_tile)
call ZGEMM('C','N',lce-lcs+1,n_cols,lre,CONE,a(1,lcs),ubound(a,1), &
vmr,ubound(vmr,1),CONE,umc(lcs,1),ubound(umc,1))
call ZGEMM('C','N',lce-lcs+1,n_cols,lre,CONE,a(1,lcs),ubound(a,dim=1), &
vmr,ubound(vmr,dim=1),CONE,umc(lcs,1),ubound(umc,dim=1))
if (i==0) cycle
lre = min(l_rows,i*l_rows_tile)
call ZGEMM('N','N',lre,n_cols,lce-lcs+1,CONE,a(1,lcs),lda, &
umc(lcs,n_cols+1),ubound(umc,1),CONE,vmr(1,n_cols+1),ubound(vmr,1))
umc(lcs,n_cols+1),ubound(umc,dim=1),CONE,vmr(1,n_cols+1),ubound(vmr,dim=1))
enddo
endif
......@@ -3459,9 +3460,9 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
! global tile size is smaller than the global remaining matrix
if (tile_size < istep*nbw) then
call elpa_reduce_add_vectors (vmr(1,n_cols+1),2*ubound(vmr,1),mpi_comm_rows, &
umc, 2*ubound(umc,1), mpi_comm_cols, &
2*istep*nbw, n_cols, 2*nblk)
call elpa_reduce_add_vectors_complex (vmr(1,n_cols+1),ubound(vmr,dim=1),mpi_comm_rows, &
umc, ubound(umc,dim=1), mpi_comm_cols, &
istep*nbw, n_cols, nblk)
endif
if (l_cols>0) then
......@@ -3473,23 +3474,25 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
! U = U * Tmat**T
call ztrmm('Right','Upper','C','Nonunit',l_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,1),umc,ubound(umc,1))
call ztrmm('Right','Upper','C','Nonunit',l_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),umc,ubound(umc,dim=1))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
call zgemm('C','N',n_cols,n_cols,l_cols,CONE,umc,ubound(umc,1),umc(1,n_cols+1),ubound(umc,1),CZERO,vav,ubound(vav,1))
call ztrmm('Right','Upper','C','Nonunit',n_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,1),vav,ubound(vav,1))
call zgemm('C','N',n_cols,n_cols,l_cols,CONE,umc,ubound(umc,dim=1),umc(1,n_cols+1), &
ubound(umc,dim=1),CZERO,vav,ubound(vav,dim=1))
call ztrmm('Right','Upper','C','Nonunit',n_cols,n_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),vav,ubound(vav,dim=1))
call herm_matrix_allreduce(n_cols,vav,ubound(vav,1),mpi_comm_cols)
call herm_matrix_allreduce(n_cols,vav,ubound(vav,dim=1),mpi_comm_cols)
! U = U - 0.5 * V * VAV
call zgemm('N','N',l_cols,n_cols,n_cols,(-0.5d0,0.d0),umc(1,n_cols+1),ubound(umc,1),vav,ubound(vav,1),CONE,umc,ubound(umc,1))
call zgemm('N','N',l_cols,n_cols,n_cols,(-0.5d0,0.d0),umc(1,n_cols+1),ubound(umc,dim=1),vav,ubound(vav,dim=1), &
CONE,umc,ubound(umc,dim=1))
! Transpose umc -> umr (stored in vmr, second half)
call elpa_transpose_vectors (umc, 2*ubound(umc,1), mpi_comm_cols, &
vmr(1,n_cols+1), 2*ubound(vmr,1), mpi_comm_rows, &
1, 2*istep*nbw, n_cols, 2*nblk)
call elpa_transpose_vectors_complex (umc, ubound(umc,dim=1), mpi_comm_cols, &
vmr(1,n_cols+1), ubound(vmr,dim=1), mpi_comm_rows, &
1, istep*nbw, n_cols, nblk)
! A = A - V*U**T - U*V**T
......@@ -3499,7 +3502,7 @@ subroutine bandred_complex(na, a, lda, nblk, nbw, mpi_comm_rows, mpi_comm_cols,
lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<1) cycle
call zgemm('N','C',lre,lce-lcs+1,2*n_cols,-CONE, &
vmr,ubound(vmr,1),umc(lcs,1),ubound(umc,1), &
vmr,ubound(vmr,dim=1),umc(lcs,1),ubound(umc,dim=1), &
CONE,a(1,lcs),lda)
enddo
......@@ -3679,15 +3682,15 @@ subroutine trans_ev_band_to_full_complex(na, nqc, nblk, nbw, a, lda, tmat, q, ld
! Q = Q - V * T**T * V**T * Q
if (l_rows>0) then
call zgemm('C','N',n_cols,l_cols,l_rows,CONE,hvm,ubound(hvm,1), &
call zgemm('C','N',n_cols,l_cols,l_rows,CONE,hvm,ubound(hvm,dim=1), &
q,ldq,CZERO,tmp1,n_cols)
else
tmp1(1:l_cols*n_cols) = 0
endif
call mpi_allreduce(tmp1,tmp2,n_cols*l_cols,MPI_DOUBLE_COMPLEX,MPI_SUM,mpi_comm_rows,mpierr)
if (l_rows>0) then
call ztrmm('L','U','C','N',n_cols,l_cols,CONE,tmat(1,1,istep),ubound(tmat,1),tmp2,n_cols)
call zgemm('N','N',l_rows,l_cols,n_cols,-CONE,hvm,ubound(hvm,1), &
call ztrmm('L','U','C','N',n_cols,l_cols,CONE,tmat(1,1,istep),ubound(tmat,dim=1),tmp2,n_cols)
call zgemm('N','N',l_rows,l_cols,n_cols,-CONE,hvm,ubound(hvm,dim=1), &
tmp2,n_cols,CONE,q,ldq)
endif
......
......@@ -64,21 +64,21 @@
end function
!c> int elpa_solve_evp_real_stage1(int na, int nev, int ncols, double *a, int lda, double *ev, double *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols);
function solve_elpa1_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols) &
!c> int elpa_solve_evp_real_stage1(int na, int nev, double *a, int lda, double *ev, double *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols);
function solve_elpa1_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, mpi_comm_rows, mpi_comm_cols) &
result(success) bind(C,name="elpa_solve_evp_real_1stage")
use, intrinsic :: iso_c_binding
use elpa1, only : solve_evp_real
integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows
real(kind=c_double) :: a(1:lda,1:ncols), ev(1:na), q(1:ldq,1:ncols)
integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows
real(kind=c_double) :: a(1:lda,1:matrixCols), ev(1:na), q(1:ldq,1:matrixCols)
logical :: successFortran
successFortran = solve_evp_real(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols)
successFortran = solve_evp_real(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
if (successFortran) then
success = 1
......@@ -88,22 +88,22 @@
end function
! int elpa_solve_evp_complex_stage1(int na, int nev, int ncols double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols);
function solve_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols) &
! int elpa_solve_evp_complex_stage1(int na, int nev, double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols);
function solve_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, mpi_comm_rows, mpi_comm_cols) &
result(success) bind(C,name="elpa_solve_evp_complex_1stage")
use, intrinsic :: iso_c_binding
use elpa1, only : solve_evp_complex
integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows
complex(kind=c_double_complex) :: a(1:lda,1:ncols), q(1:ldq,1:ncols)
integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows
complex(kind=c_double_complex) :: a(1:lda,1:matrixCols), q(1:ldq,1:matrixCols)
real(kind=c_double) :: ev(1:na)
logical :: successFortran
successFortran = solve_evp_complex(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols)
successFortran = solve_evp_complex(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
if (successFortran) then
success = 1
......@@ -113,9 +113,9 @@
end function
!c> int elpa_solve_evp_real_stage2(int na, int nev, int ncols, double *a, int lda, double *ev, double *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols, int THIS_REAL_ELPA_KERNEL_API, int useQR);
function solve_elpa2_evp_real_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
!c> int elpa_solve_evp_real_stage2(int na, int nev, double *a, int lda, double *ev, double *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols, int THIS_REAL_ELPA_KERNEL_API, int useQR);
function solve_elpa2_evp_real_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
THIS_REAL_ELPA_KERNEL_API, useQR) &
result(success) bind(C,name="elpa_solve_evp_real_2stage")
......@@ -123,10 +123,10 @@
use elpa2, only : solve_evp_real_2stage
integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows, &
integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows, &
mpi_comm_all
integer(kind=c_int), value, intent(in) :: THIS_REAL_ELPA_KERNEL_API, useQR
real(kind=c_double) :: a(1:lda,1:ncols), ev(1:na), q(1:ldq,1:ncols)
real(kind=c_double) :: a(1:lda,1:matrixCols), ev(1:na), q(1:ldq,1:matrixCols)
......@@ -138,7 +138,7 @@
useQRFortran = .true.
endif
successFortran = solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
successFortran = solve_evp_real_2stage(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
THIS_REAL_ELPA_KERNEL_API, useQRFortran)
if (successFortran) then
......@@ -149,9 +149,9 @@
end function
! int elpa_solve_evp_complex_stage2(int na, int nev, int ncols, double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int mpi_comm_rows, int mpi_comm_cols);
function solve_elpa2_evp_complex_wrapper(na, nev, ncols, a, lda, ev, q, ldq, nblk, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
! int elpa_solve_evp_complex_stage2(int na, int nev, double_complex *a, int lda, double *ev, double_complex *q, int ldq, int nblk, int matrixCols, int mpi_comm_rows, int mpi_comm_cols);
function solve_elpa2_evp_complex_wrapper(na, nev, a, lda, ev, q, ldq, nblk, &
matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
THIS_COMPLEX_ELPA_KERNEL_API) &
result(success) bind(C,name="elpa_solve_evp_complex_2stage")
......@@ -159,14 +159,14 @@
use elpa2, only : solve_evp_complex_2stage
integer(kind=c_int) :: success
integer(kind=c_int), value, intent(in) :: na, nev, ncols, lda, ldq, nblk, mpi_comm_cols, mpi_comm_rows, &
integer(kind=c_int), value, intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_cols, mpi_comm_rows, &
mpi_comm_all
integer(kind=c_int), value, intent(in) :: THIS_COMPLEX_ELPA_KERNEL_API
complex(kind=c_double_complex) :: a(1:lda,1:ncols), q(1:ldq,1:ncols)
complex(kind=c_double_complex) :: a(1:lda,1:matrixCols), q(1:ldq,1:matrixCols)
real(kind=c_double) :: ev(1:na)
logical :: successFortran
successFortran = solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, mpi_comm_rows, mpi_comm_cols, &
successFortran = solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, &
mpi_comm_all, THIS_COMPLEX_ELPA_KERNEL_API)
if (successFortran) then
......
#if REALCASE==1
subroutine elpa_reduce_add_vectors_real(vmat_s,ld_s,comm_s,vmat_t,ld_t,comm_t,nvr,nvc,nblk)
#endif
#if COMPLEXCASE==1
subroutine elpa_reduce_add_vectors_complex(vmat_s,ld_s,comm_s,vmat_t,ld_t,comm_t,nvr,nvc,nblk)
#endif
!-------------------------------------------------------------------------------
! Sum-reduce the vectors stored in vmat_s over the communicator comm_t and add
! the reduced values to vmat_t on the rank that is the root of each reduce
! (described by the original author as the processors owning the diagonal).
! vmat_t is distributed over comm_t.
!
! In contrast to elpa_transpose_vectors, vmat_s is NOT replicated identically
! among the members of comm_t (otherwise no reduce would be needed).
! The caller still has to perform an allreduce of vmat_t after this routine.
!
! Arguments:
!   vmat_s   (in)     array of vectors to be reduced and added
!   ld_s     (in)     leading dimension of vmat_s
!   comm_s   (in)     communicator over which vmat_s is distributed
!   vmat_t   (inout)  array of vectors to which the reduced vmat_s is added
!   ld_t     (in)     leading dimension of vmat_t
!   comm_t   (in)     communicator over which vmat_t is distributed
!   nvr      (in)     global length of vmat_s/vmat_t
!   nvc      (in)     number of columns in vmat_s/vmat_t
!   nblk     (in)     block size of the block cyclic distribution
!
! least_common_multiple is not declared here; it has to be accessible from
! the enclosing scope (the original code mentions ELPA1 as its origin).
!-------------------------------------------------------------------------------

  implicit none

  include 'mpif.h'

  integer, intent(in)              :: ld_s, comm_s, ld_t, comm_t, nvr, nvc, nblk
  DATATYPE*BYTESIZE, intent(in)    :: vmat_s(ld_s,nvc)
  DATATYPE*BYTESIZE, intent(inout) :: vmat_t(ld_t,nvc)

  DATATYPE*BYTESIZE, allocatable   :: sendbuf(:), recvbuf(:)

  integer :: rank_s, nprocs_s, rank_t, nprocs_t
  integer :: phase, owner_s, owner_t
  integer :: icol, iblk, row0, nrows, pos
  integer :: period, nblocks, buflen
  integer :: reduce_dtype, ierr

  ! MPI element type matching DATATYPE*BYTESIZE for the selected case
#if REALCASE==1
  reduce_dtype = MPI_REAL8
#endif
#if COMPLEXCASE==1
  reduce_dtype = MPI_DOUBLE_COMPLEX
#endif

  call mpi_comm_rank(comm_s, rank_s,   ierr)
  call mpi_comm_size(comm_s, nprocs_s, ierr)
  call mpi_comm_rank(comm_t, rank_t,   ierr)
  call mpi_comm_size(comm_t, nprocs_t, ierr)

  ! See elpa_transpose_vectors for the basic idea.
  ! The communication pattern repeats in the global matrix after
  ! lcm(nprocs_s, nprocs_t) blocks.
  period  = least_common_multiple(nprocs_s, nprocs_t)
  nblocks = (nvr + nblk - 1)/nblk              ! number of blocks covering nvr rows

  ! One slot of nblk entries per (column, block handled in a single phase)
  buflen = ((nblocks + period - 1)/period) * nblk * nvc

  allocate(sendbuf(buflen))
  allocate(recvbuf(buflen))
  sendbuf = 0
  recvbuf = 0

  do phase = 0, period-1

    owner_s = mod(phase, nprocs_s)
    owner_t = mod(phase, nprocs_t)

    ! Only ranks whose comm_s rank matches this phase take part in it
    if (rank_s /= owner_s) cycle

    ! Pack the blocks belonging to this phase, column by column
    pos = 0
    do icol = 1, nvc
      do iblk = phase, nblocks-1, period
        row0  = (iblk/nprocs_s)*nblk           ! local start of block iblk in vmat_s
        nrows = min(nvr - iblk*nblk, nblk)     ! the last block may be shorter
        sendbuf(pos+1:pos+nrows) = vmat_s(row0+1:row0+nrows, icol)
        pos = pos + nblk                       ! slots always advance by a full block
      enddo
    enddo

    if (pos > 0) call mpi_reduce(sendbuf, recvbuf, pos, reduce_dtype, MPI_SUM, owner_t, comm_t, ierr)

    ! Only the root of the reduce accumulates the result into vmat_t
    if (rank_t /= owner_t) cycle

    pos = 0
    do icol = 1, nvc
      do iblk = phase, nblocks-1, period
        row0  = (iblk/nprocs_t)*nblk           ! local start of block iblk in vmat_t
        nrows = min(nvr - iblk*nblk, nblk)
        vmat_t(row0+1:row0+nrows, icol) = vmat_t(row0+1:row0+nrows, icol) + recvbuf(pos+1:pos+nrows)
        pos = pos + nblk
      enddo
    enddo

  enddo

  deallocate(sendbuf, recvbuf)

end subroutine
#if REALCASE==1
subroutine elpa_transpose_vectors_real(vmat_s,ld_s,comm_s,vmat_t,ld_t,comm_t,nvs,nvr,nvc,nblk)
#endif
#if COMPLEXCASE==1
subroutine elpa_transpose_vectors_complex(vmat_s,ld_s,comm_s,vmat_t,ld_t,comm_t,nvs,nvr,nvc,nblk)
#endif
!-------------------------------------------------------------------------------
! This routine transposes an array of vectors which are distributed in
! communicator comm_s into its transposed form distributed in communicator comm_t.
! There must be an identical copy of vmat_s in every communicator comm_s.
! After this routine, there is an identical copy of vmat_t in every communicator comm_t.
!
! vmat_s original array of vectors
! ld_s leading dimension of vmat_s
! comm_s communicator over which vmat_s is distributed
! vmat_t array of vectors in transposed form
! ld_t leading dimension of vmat_t
! comm_t communicator over which vmat_t is distributed
! nvs global index where to start in vmat_s/vmat_t