Commit 68e4bfa8 authored by Pavel Kus's avatar Pavel Kus

real/complex unification of elpa1_trans_ev

parent 1839109d
......@@ -98,61 +98,28 @@
use precision
use elpa_abstract_impl
implicit none
#include "../general/precision_kinds.F90"
class(elpa_abstract_impl_t), intent(inout) :: obj
integer(kind=ik), intent(in) :: na, nqc, lda, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
real(kind=REAL_DATATYPE), intent(in) :: tau(na)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), intent(in) :: tau(na)
#endif
MATH_DATATYPE(kind=rck), intent(in) :: tau(na)
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE), intent(inout) :: a_mat(lda,*), q_mat(ldq,*)
MATH_DATATYPE(kind=rck), intent(inout) :: a_mat(lda,*), q_mat(ldq,*)
#else
real(kind=REAL_DATATYPE), intent(inout) :: a_mat(lda,matrixCols), q_mat(ldq,matrixCols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE), intent(inout) :: a_mat(lda,*), q_mat(ldq,*)
#else
complex(kind=COMPLEX_DATATYPE), intent(inout) :: a_mat(lda,matrixCols), q_mat(ldq,matrixCols)
#endif
MATH_DATATYPE(kind=rck), intent(inout) :: a_mat(lda,matrixCols), q_mat(ldq,matrixCols)
#endif
logical, intent(in) :: useGPU
integer(kind=ik) :: max_stored_rows
#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
real(kind=rk8), parameter :: ZERO = 0.0_rk8, ONE = 1.0_rk8
#else
real(kind=rk4), parameter :: ZERO = 0.0_rk4, ONE = 1.0_rk4
#endif
#endif
#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
complex(kind=ck8), parameter :: ZERO = (0.0_rk8,0.0_rk8), ONE = (1.0_rk8,0.0_rk8)
#else
complex(kind=ck4), parameter :: ZERO = (0.0_rk4,0.0_rk4), ONE = (1.0_rk4,0.0_rk4)
#endif
#endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: totalblocks, max_blocks_row, max_blocks_col, max_local_rows, max_local_cols
integer(kind=ik) :: l_cols, l_rows, l_colh, nstor
integer(kind=ik) :: istep, n, nc, ic, ics, ice, nb, cur_pcol
integer(kind=ik) :: hvn_ubnd, hvm_ubnd
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
real(kind=REAL_DATATYPE), allocatable :: tmat(:,:), h1(:), h2(:), hvm1(:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
complex(kind=COMPLEX_DATATYPE), allocatable :: tmat(:,:), h1(:), h2(:), hvm1(:)
#endif
MATH_DATATYPE(kind=rck), allocatable :: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
MATH_DATATYPE(kind=rck), allocatable :: tmat(:,:), h1(:), h2(:), hvm1(:)
integer(kind=ik) :: istat
character(200) :: errorMessage
character(20) :: gpuString
......@@ -294,15 +261,7 @@
#ifdef WITH_MPI
call obj%timer%start("mpi_communication")
if (nb>0) &
call MPI_Bcast(hvb, nb, &
#if REALCASE == 1
&MPI_REAL_PRECISION&
#endif
#if COMPLEXCASE == 1
&MPI_COMPLEX_PRECISION&
#endif
, cur_pcol, mpi_comm_cols, mpierr)
call MPI_Bcast(hvb, nb, MPI_MATH_DATATYPE_PRECISION , cur_pcol, mpi_comm_cols, mpierr)
call obj%timer%stop("mpi_communication")
#endif /* WITH_MPI */
......@@ -341,14 +300,7 @@
enddo
#ifdef WITH_MPI
call obj%timer%start("mpi_communication")
if (nc>0) call mpi_allreduce( h1, h2, nc, &
#if REALCASE == 1
&MPI_REAL_PRECISION&
#endif
#if COMPLEXCASE == 1
&MPI_COMPLEX_PRECISION&
#endif
&, MPI_SUM, mpi_comm_rows, mpierr)
if (nc>0) call mpi_allreduce( h1, h2, nc, MPI_MATH_DATATYPE_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
call obj%timer%stop("mpi_communication")
#else /* WITH_MPI */
......@@ -361,13 +313,7 @@
tmat(1,1) = tau(ice-nstor+1)
do n = 1, nstor-1
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_TRMV('L', 'T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMV('L', 'C', 'N', &
#endif
n, tmat, max_stored_rows, h2(nc+1), 1)
call PRECISION_TRMV('L', BLAS_TRANS_OR_CONJ , 'N', n, tmat, max_stored_rows, h2(nc+1), 1)
call obj%timer%stop("blas")
tmat(n+1,1:n) = &
......@@ -404,25 +350,15 @@
if (l_rows>0) then
if (useGPU) then
call obj%timer%start("cublas")
#if REALCASE == 1
call cublas_PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMM('C', 'N', &
#endif
call cublas_PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
nstor, l_cols, l_rows, ONE, hvm_dev, hvm_ubnd, &
q_dev, ldq, ZERO, tmp_dev, nstor)
q_dev, ldq, ZERO, tmp_dev, nstor)
call obj%timer%stop("cublas")
else ! useGPU
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
nstor, l_cols, l_rows, ONE, hvm, ubound(hvm,dim=1), &
q_mat, ldq, ZERO, tmp1, nstor)
call obj%timer%stop("blas")
......@@ -447,14 +383,7 @@
check_memcpy_cuda("trans_ev", successCUDA)
endif
call obj%timer%start("mpi_communication")
call mpi_allreduce(tmp1, tmp2, nstor*l_cols, &
#if REALCASE == 1
&MPI_REAL_PRECISION&
#endif
#if COMPLEXCASE == 1
&MPI_COMPLEX_PRECISION&
#endif
&, MPI_SUM, mpi_comm_rows, mpierr)
call mpi_allreduce(tmp1, tmp2, nstor*l_cols, MPI_MATH_DATATYPE_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
call obj%timer%stop("mpi_communication")
! copy back tmp2 - after reduction...
if (useGPU) then
......
......@@ -143,7 +143,7 @@ call prmat(na,useGpu,a_mat,a_dev,lda,matrixCols,nblk,my_prow,my_pcol,np_rows,np_
integer(kind=ik) :: omp_get_thread_num, omp_get_num_threads, omp_get_max_threads
#endif
real(kind=REAL_DATATYPE) :: vnorm2
real(kind=rk) :: vnorm2
MATH_DATATYPE(kind=rck) :: vav, x, aux(2*max_stored_uv), aux1(2), aux2(2), vrl, xf
#if COMPLEXCASE == 1
complex(kind=rck) :: aux3(1)
......@@ -167,7 +167,7 @@ call prmat(na,useGpu,a_mat,a_dev,lda,matrixCols,nblk,my_prow,my_pcol,np_rows,np_
MATH_DATATYPE(kind=rck), allocatable :: ur_p(:,:), uc_p(:,:)
#endif
real(kind=rk), allocatable :: tmp_real(:)
real(kind=rk), allocatable :: tmp_real(:)
integer(kind=ik) :: min_tile_size, error
integer(kind=ik) :: istat
character(200) :: errorMessage
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment