Commit 6febaf23 authored by Pavel Kus's avatar Pavel Kus

passing useGPU to tridiagonal solver

parent 8b5fa5e3
...@@ -57,7 +57,7 @@ ...@@ -57,7 +57,7 @@
subroutine merge_systems_& subroutine merge_systems_&
&PRECISION & &PRECISION &
(obj, na, nm, d, e, q, ldq, nqoff, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, & (obj, na, nm, d, e, q, ldq, nqoff, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, &
l_col, p_col, l_col_out, p_col_out, npc_0, npc_n, wantDebug, success) l_col, p_col, l_col_out, p_col_out, npc_0, npc_n, useGPU, wantDebug, success)
use precision use precision
use elpa_abstract_impl use elpa_abstract_impl
...@@ -73,7 +73,7 @@ ...@@ -73,7 +73,7 @@
#else #else
real(kind=REAL_DATATYPE), intent(inout) :: q(ldq,matrixCols) real(kind=REAL_DATATYPE), intent(inout) :: q(ldq,matrixCols)
#endif #endif
logical, intent(in) :: wantDebug logical, intent(in) :: useGPU, wantDebug
logical, intent(out) :: success logical, intent(out) :: success
integer(kind=ik), parameter :: max_strip=128 integer(kind=ik), parameter :: max_strip=128
...@@ -729,9 +729,14 @@ ...@@ -729,9 +729,14 @@
! 1.d0,qtmp2(1,1),ubound(qtmp2,1)) ! 1.d0,qtmp2(1,1),ubound(qtmp2,1))
! else ! else
call obj%timer%start("blas") call obj%timer%start("blas")
call obj%timer%start("gemm-first")
if (l_rnm>0 .and. ncnt>0 .and. nnzu>0) & if (l_rnm>0 .and. ncnt>0 .and. nnzu>0) &
call PRECISION_GEMM('N', 'N', l_rnm, ncnt, nnzu, 1.0_rk, qtmp1, ubound(qtmp1,dim=1), ev, ubound(ev,dim=1), & !write(*,*) "merging-first", l_rnm, ncnt, nnzu
call PRECISION_GEMM('N', 'N', l_rnm, ncnt, nnzu, &
1.0_rk, qtmp1, ubound(qtmp1,dim=1), &
ev, ubound(ev,dim=1), &
1.0_rk, qtmp2(1,1), ubound(qtmp2,dim=1)) 1.0_rk, qtmp2(1,1), ubound(qtmp2,dim=1))
call obj%timer%stop("gemm-first")
call obj%timer%stop("blas") call obj%timer%stop("blas")
! endif ! useGPU ! endif ! useGPU
! Compute eigenvectors of the rank-1 modified matrix. ! Compute eigenvectors of the rank-1 modified matrix.
...@@ -756,9 +761,14 @@ ...@@ -756,9 +761,14 @@
! 1.d0,qtmp2(l_rnm+1,1),ubound(qtmp2,1)) ! 1.d0,qtmp2(l_rnm+1,1),ubound(qtmp2,1))
! else ! else
call obj%timer%start("blas") call obj%timer%start("blas")
call obj%timer%start("gemm")
if (l_rows-l_rnm>0 .and. ncnt>0 .and. nnzl>0) & if (l_rows-l_rnm>0 .and. ncnt>0 .and. nnzl>0) &
call PRECISION_GEMM('N', 'N', l_rows-l_rnm, ncnt, nnzl, 1.0_rk, qtmp1(l_rnm+1,1), ubound(qtmp1,dim=1), ev, & !write(*,*) "merging ", l_rows-l_rnm, ncnt, nnzl
ubound(ev,dim=1), 1.0_rk, qtmp2(l_rnm+1,1), ubound(qtmp2,dim=1)) call PRECISION_GEMM('N', 'N', l_rows-l_rnm, ncnt, nnzl, &
1.0_rk, qtmp1(l_rnm+1,1), ubound(qtmp1,dim=1), &
ev, ubound(ev,dim=1), &
1.0_rk, qtmp2(l_rnm+1,1), ubound(qtmp2,dim=1))
call obj%timer%stop("gemm")
call obj%timer%stop("blas") call obj%timer%stop("blas")
! endif ! useGPU ! endif ! useGPU
! Put partial result into (output) Q ! Put partial result into (output) Q
...@@ -767,8 +777,8 @@ ...@@ -767,8 +777,8 @@
q(l_rqs:l_rqe,l_col_out(idxq1(i+ns))) = qtmp2(1:l_rows,i) q(l_rqs:l_rqe,l_col_out(idxq1(i+ns))) = qtmp2(1:l_rows,i)
enddo enddo
enddo enddo !ns = 0, nqcols1-1, max_strip ! strimining loop
enddo enddo !do np = 1, npc_n
deallocate(ev, qtmp1, qtmp2, stat=istat, errmsg=errorMessage) deallocate(ev, qtmp1, qtmp2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then if (istat .ne. 0) then
......
...@@ -57,7 +57,7 @@ ...@@ -57,7 +57,7 @@
subroutine solve_tridi_& subroutine solve_tridi_&
&PRECISION_AND_SUFFIX & &PRECISION_AND_SUFFIX &
( obj, na, nev, d, e, q, ldq, nblk, matrixCols, mpi_comm_rows, & ( obj, na, nev, d, e, q, ldq, nblk, matrixCols, mpi_comm_rows, &
mpi_comm_cols, wantDebug, success ) mpi_comm_cols, useGPU, wantDebug, success )
use precision use precision
use elpa_abstract_impl use elpa_abstract_impl
...@@ -71,7 +71,7 @@ subroutine solve_tridi_& ...@@ -71,7 +71,7 @@ subroutine solve_tridi_&
#else #else
real(kind=REAL_DATATYPE), intent(inout) :: q(ldq,matrixCols) real(kind=REAL_DATATYPE), intent(inout) :: q(ldq,matrixCols)
#endif #endif
logical, intent(in) :: wantDebug logical, intent(in) :: useGPU, wantDebug
logical, intent(out) :: success logical, intent(out) :: success
integer(kind=ik) :: i, j, n, np, nc, nev1, l_cols, l_rows integer(kind=ik) :: i, j, n, np, nc, nev1, l_cols, l_rows
...@@ -145,7 +145,7 @@ subroutine solve_tridi_& ...@@ -145,7 +145,7 @@ subroutine solve_tridi_&
call solve_tridi_col_& call solve_tridi_col_&
&PRECISION_AND_SUFFIX & &PRECISION_AND_SUFFIX &
(obj, l_cols, nev1, nc, d(nc+1), e(nc+1), q, ldq, nblk, & (obj, l_cols, nev1, nc, d(nc+1), e(nc+1), q, ldq, nblk, &
matrixCols, mpi_comm_rows, wantDebug, success) matrixCols, mpi_comm_rows, useGPU, wantDebug, success)
if (.not.(success)) then if (.not.(success)) then
call obj%timer%stop("solve_tridi" // PRECISION_SUFFIX) call obj%timer%stop("solve_tridi" // PRECISION_SUFFIX)
return return
...@@ -220,7 +220,7 @@ subroutine solve_tridi_& ...@@ -220,7 +220,7 @@ subroutine solve_tridi_&
! Recursively merge sub problems ! Recursively merge sub problems
call merge_recursive_& call merge_recursive_&
&PRECISION & &PRECISION &
(obj, 0, np_cols, wantDebug, success) (obj, 0, np_cols, useGPU, wantDebug, success)
if (.not.(success)) then if (.not.(success)) then
call obj%timer%stop("solve_tridi" // PRECISION_SUFFIX) call obj%timer%stop("solve_tridi" // PRECISION_SUFFIX)
return return
...@@ -238,7 +238,7 @@ subroutine solve_tridi_& ...@@ -238,7 +238,7 @@ subroutine solve_tridi_&
contains contains
recursive subroutine merge_recursive_& recursive subroutine merge_recursive_&
&PRECISION & &PRECISION &
(obj, np_off, nprocs, wantDebug, success) (obj, np_off, nprocs, useGPU, wantDebug, success)
use precision use precision
use elpa_abstract_impl use elpa_abstract_impl
implicit none implicit none
...@@ -252,7 +252,7 @@ subroutine solve_tridi_& ...@@ -252,7 +252,7 @@ subroutine solve_tridi_&
#ifdef WITH_MPI #ifdef WITH_MPI
! integer(kind=ik) :: my_mpi_status(mpi_status_size) ! integer(kind=ik) :: my_mpi_status(mpi_status_size)
#endif #endif
logical, intent(in) :: wantDebug logical, intent(in) :: useGPU, wantDebug
logical, intent(out) :: success logical, intent(out) :: success
success = .true. success = .true.
...@@ -270,11 +270,11 @@ subroutine solve_tridi_& ...@@ -270,11 +270,11 @@ subroutine solve_tridi_&
if (np1 > 1) call merge_recursive_& if (np1 > 1) call merge_recursive_&
&PRECISION & &PRECISION &
(obj, np_off, np1, wantDebug, success) (obj, np_off, np1, useGPU, wantDebug, success)
if (.not.(success)) return if (.not.(success)) return
if (np2 > 1) call merge_recursive_& if (np2 > 1) call merge_recursive_&
&PRECISION & &PRECISION &
(obj, np_off+np1, np2, wantDebug, success) (obj, np_off+np1, np2, useGPU, wantDebug, success)
if (.not.(success)) return if (.not.(success)) return
noff = limits(np_off) noff = limits(np_off)
...@@ -328,7 +328,7 @@ subroutine solve_tridi_& ...@@ -328,7 +328,7 @@ subroutine solve_tridi_&
&PRECISION & &PRECISION &
(obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, noff, & (obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, noff, &
nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, l_col, p_col, & nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, l_col, p_col, &
l_col_bc, p_col_bc, np_off, nprocs, wantDebug, success ) l_col_bc, p_col_bc, np_off, nprocs, useGPU, wantDebug, success )
if (.not.(success)) return if (.not.(success)) return
else else
! Not last merge, leave dense column distribution ! Not last merge, leave dense column distribution
...@@ -336,7 +336,7 @@ subroutine solve_tridi_& ...@@ -336,7 +336,7 @@ subroutine solve_tridi_&
&PRECISION & &PRECISION &
(obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, noff, & (obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, noff, &
nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, l_col(noff+1), p_col(noff+1), & nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, l_col(noff+1), p_col(noff+1), &
l_col(noff+1), p_col(noff+1), np_off, nprocs, wantDebug, success ) l_col(noff+1), p_col(noff+1), np_off, nprocs, useGPU, wantDebug, success )
if (.not.(success)) return if (.not.(success)) return
endif endif
end subroutine merge_recursive_& end subroutine merge_recursive_&
...@@ -347,7 +347,7 @@ subroutine solve_tridi_& ...@@ -347,7 +347,7 @@ subroutine solve_tridi_&
subroutine solve_tridi_col_& subroutine solve_tridi_col_&
&PRECISION_AND_SUFFIX & &PRECISION_AND_SUFFIX &
( obj, na, nev, nqoff, d, e, q, ldq, nblk, matrixCols, mpi_comm_rows, wantDebug, success ) ( obj, na, nev, nqoff, d, e, q, ldq, nblk, matrixCols, mpi_comm_rows, useGPU, wantDebug, success )
! Solves the symmetric, tridiagonal eigenvalue problem on one processor column ! Solves the symmetric, tridiagonal eigenvalue problem on one processor column
! with the divide and conquer method. ! with the divide and conquer method.
...@@ -373,7 +373,7 @@ subroutine solve_tridi_& ...@@ -373,7 +373,7 @@ subroutine solve_tridi_&
integer(kind=ik) :: my_prow, np_rows, mpierr integer(kind=ik) :: my_prow, np_rows, mpierr
integer(kind=ik), allocatable :: limits(:), l_col(:), p_col_i(:), p_col_o(:) integer(kind=ik), allocatable :: limits(:), l_col(:), p_col_i(:), p_col_o(:)
logical, intent(in) :: wantDebug logical, intent(in) :: useGPU, wantDebug
logical, intent(out) :: success logical, intent(out) :: success
integer(kind=ik) :: istat integer(kind=ik) :: istat
character(200) :: errorMessage character(200) :: errorMessage
...@@ -550,7 +550,7 @@ subroutine solve_tridi_& ...@@ -550,7 +550,7 @@ subroutine solve_tridi_&
&PRECISION & &PRECISION &
(obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, nqoff+noff, nblk, & (obj, nlen, nmid, d(noff+1), e(noff+nmid), q, ldq, nqoff+noff, nblk, &
matrixCols, mpi_comm_rows, mpi_comm_self, l_col(noff+1), p_col_i(noff+1), & matrixCols, mpi_comm_rows, mpi_comm_self, l_col(noff+1), p_col_i(noff+1), &
l_col(noff+1), p_col_o(noff+1), 0, 1, wantDebug, success) l_col(noff+1), p_col_o(noff+1), 0, 1, useGPU, wantDebug, success)
if (.not.(success)) return if (.not.(success)) return
enddo enddo
......
...@@ -314,7 +314,7 @@ function elpa_solve_evp_& ...@@ -314,7 +314,7 @@ function elpa_solve_evp_&
#if COMPLEXCASE == 1 #if COMPLEXCASE == 1
q_real, l_rows, & q_real, l_rows, &
#endif #endif
nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, wantDebug, success) nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, do_useGPU, wantDebug, success)
call obj%timer%stop("solve") call obj%timer%stop("solve")
if (.not.(success)) return if (.not.(success)) return
endif !do_solve endif !do_solve
......
...@@ -116,7 +116,7 @@ ...@@ -116,7 +116,7 @@
call solve_tridi_& call solve_tridi_&
&PRECISION& &PRECISION&
&_private_impl(obj, na, nev, d, e, q, ldq, nblk, matrixCols, & &_private_impl(obj, na, nev, d, e, q, ldq, nblk, matrixCols, &
mpi_comm_rows, mpi_comm_cols, wantDebug, success) mpi_comm_rows, mpi_comm_cols,.false., wantDebug, success)
call obj%timer%stop("elpa_solve_tridi_public_& call obj%timer%stop("elpa_solve_tridi_public_&
&MATH_DATATYPE& &MATH_DATATYPE&
......
...@@ -541,7 +541,7 @@ ...@@ -541,7 +541,7 @@
#if COMPLEXCASE == 1 #if COMPLEXCASE == 1
q_real, ubound(q_real,dim=1), & q_real, ubound(q_real,dim=1), &
#endif #endif
nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, wantDebug, success) nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, do_useGPU, wantDebug, success)
call obj%timer%stop("solve") call obj%timer%stop("solve")
if (.not.(success)) return if (.not.(success)) return
endif ! do_solve_tridi endif ! do_solve_tridi
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment