Commit 58c393cd authored by Pavel Kus

further real/complex unifications

in elpa2_bandred_template.F90
parent e62e2dd8
......@@ -591,13 +591,7 @@
#ifdef WITH_MPI
if (wantDebug) call obj%timer%start("mpi_communication")
call mpi_allreduce(aux1, aux2, 2, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
call mpi_allreduce(aux1, aux2, 2, MPI_MATH_DATATYPE_PRECISION, &
MPI_SUM, mpi_comm_rows, mpierr)
if (wantDebug) call obj%timer%stop("mpi_communication")
......@@ -609,14 +603,11 @@
vrl = aux2(2)
! Householder transformation
#if REALCASE == 1
call hh_transform_real_&
#endif
#if COMPLEXCASE == 1
call hh_transform_complex_&
#endif
call hh_transform_&
&MATH_DATATYPE&
&_&
&PRECISION &
(obj, vrl, vnorm2, xf, tau, wantDebug)
(obj, vrl, vnorm2, xf, tau, wantDebug)
! Scale vr and store Householder Vector for back transformation
vr(1:lr) = vr(1:lr) * xf
......@@ -635,13 +626,7 @@
vr(lr+1) = tau
#ifdef WITH_MPI
if (wantDebug) call obj%timer%start("mpi_communication")
call MPI_Bcast(vr, lr+1, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
call MPI_Bcast(vr, lr+1, MPI_MATH_DATATYPE_PRECISION, &
cur_pcol, mpi_comm_cols, mpierr)
if (wantDebug) call obj%timer%stop("mpi_communication")
......@@ -754,14 +739,8 @@
!$omp single
#ifdef WITH_MPI
if (wantDebug) call obj%timer%start("mpi_communication")
if (mynlc>0) call mpi_allreduce(aux1, aux2, mynlc, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, mpi_comm_rows, mpierr)
if (mynlc>0) call mpi_allreduce(aux1, aux2, mynlc, MPI_MATH_DATATYPE_PRECISION, &
MPI_SUM, mpi_comm_rows, mpierr)
if (wantDebug) call obj%timer%stop("mpi_communication")
#else /* WITH_MPI */
if (mynlc>0) aux2 = aux1
......@@ -807,14 +786,8 @@
! Get global dot products
#ifdef WITH_MPI
if (wantDebug) call obj%timer%start("mpi_communication")
if (nlc>0) call mpi_allreduce(aux1, aux2, nlc, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION,&
#endif
MPI_SUM, mpi_comm_rows, mpierr)
if (nlc>0) call mpi_allreduce(aux1, aux2, nlc, MPI_MATH_DATATYPE_PRECISION, &
MPI_SUM, mpi_comm_rows, mpierr)
if (wantDebug) call obj%timer%stop("mpi_communication")
#else /* WITH_MPI */
if (nlc>0) aux2=aux1
......@@ -899,12 +872,7 @@
do lc=n_cols,1,-1
tau = tmat(lc,lc,istep)
if (lc<n_cols) then
#if REALCASE == 1
call PRECISION_TRMV('U', 'T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMV('U', 'C', 'N', &
#endif
call PRECISION_TRMV('U', BLAS_TRANS_OR_CONJ, 'N',&
n_cols-lc, tmat(lc+1,lc+1,istep), ubound(tmat,dim=1), vav(lc+1,lc), 1)
#if REALCASE == 1
......@@ -1040,12 +1008,7 @@
! C1 += A10' B0
if ( lce > lcs .and. i > 0 ) then
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
lce-lcs+1, n_cols, lrs-1, &
ONE, a(1,lcs), ubound(a,dim=1), &
vmrCPU(1,1), ubound(vmrCPU,dim=1), &
......@@ -1099,12 +1062,7 @@
if (useGPU) then
call obj%timer%start("cublas")
#if REALCASE == 1
call cublas_PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMM('C', 'N', &
#endif
call cublas_PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
lce-lcs+1, n_cols, lre, &
ONE, (a_dev + ((lcs-1)*lda* &
size_of_datatype)), &
......@@ -1131,12 +1089,7 @@
else ! useGPU
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
lce-lcs+1, n_cols, lre, ONE, a(1,lcs), ubound(a,dim=1), &
vmrCPU, ubound(vmrCPU,dim=1), ONE, umcCPU(lcs,1), ubound(umcCPU,dim=1))
call obj%timer%stop("blas")
......@@ -1222,13 +1175,7 @@
if (wantDebug) call obj%timer%start("mpi_communication")
call mpi_allreduce(umcCUDA, tmpCUDA, l_cols*n_cols, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
call mpi_allreduce(umcCUDA, tmpCUDA, l_cols*n_cols, MPI_MATH_DATATYPE_PRECISION, &
MPI_SUM, mpi_comm_rows, ierr)
umcCUDA(1 : l_cols * n_cols) = tmpCUDA(1 : l_cols * n_cols)
......@@ -1261,13 +1208,7 @@
#ifdef WITH_MPI
if (wantDebug) call obj%timer%start("mpi_communication")
call mpi_allreduce(umcCPU, tmpCPU, l_cols*n_cols, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
call mpi_allreduce(umcCPU, tmpCPU, l_cols*n_cols, MPI_MATH_DATATYPE_PRECISION, &
MPI_SUM, mpi_comm_rows, mpierr)
umcCPU(1:l_cols,1:n_cols) = tmpCPU(1:l_cols,1:n_cols)
if (wantDebug) call obj%timer%stop("mpi_communication")
......@@ -1306,12 +1247,7 @@
endif
call obj%timer%start("cublas")
#if REALCASE == 1
call cublas_PRECISION_TRMM('Right', 'Upper', 'Trans', 'Nonunit', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_TRMM('Right', 'Upper', 'C', 'Nonunit', &
#endif
call cublas_PRECISION_TRMM('Right', 'Upper', BLAS_TRANS_OR_CONJ, 'Nonunit', &
l_cols, n_cols, ONE, tmat_dev, nbw, umc_dev, cur_l_cols)
call obj%timer%stop("cublas")
......@@ -1325,22 +1261,12 @@
endif
call obj%timer%start("cublas")
#if REALCASE == 1
call cublas_PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMM('C', 'N', &
#endif
call cublas_PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
n_cols, n_cols, l_cols, ONE, umc_dev, cur_l_cols, &
(umc_dev+(cur_l_cols * n_cols )*size_of_datatype),cur_l_cols, &
ZERO, vav_dev, nbw)
#if REALCASE == 1
call cublas_PRECISION_TRMM('Right', 'Upper', 'Trans', 'Nonunit', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_TRMM('Right', 'Upper', 'C', 'Nonunit', &
#endif
call cublas_PRECISION_TRMM('Right', 'Upper', BLAS_TRANS_OR_CONJ, 'Nonunit', &
n_cols, n_cols, ONE, tmat_dev, nbw, vav_dev, nbw)
call obj%timer%stop("cublas")
......@@ -1355,32 +1281,17 @@
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_TRMM('Right', 'Upper', 'Trans', 'Nonunit', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMM('Right', 'Upper', 'C', 'Nonunit', &
#endif
call PRECISION_TRMM('Right', 'Upper', BLAS_TRANS_OR_CONJ, 'Nonunit', &
l_cols,n_cols, ONE, tmat(1,1,istep), ubound(tmat,dim=1), &
umcCPU, ubound(umcCPU,dim=1))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
call PRECISION_GEMM(BLAS_TRANS_OR_CONJ, 'N', &
n_cols, n_cols, l_cols, ONE, umcCPU, ubound(umcCPU,dim=1), umcCPU(1,n_cols+1), &
ubound(umcCPU,dim=1), ZERO, vav, ubound(vav,dim=1))
#if REALCASE == 1
call PRECISION_TRMM('Right', 'Upper', 'Trans', 'Nonunit', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMM('Right', 'Upper', 'C', 'Nonunit', &
#endif
call PRECISION_TRMM('Right', 'Upper', BLAS_TRANS_OR_CONJ, 'Nonunit', &
n_cols, n_cols, ONE, tmat(1,1,istep), &
ubound(tmat,dim=1), vav, ubound(vav,dim=1))
call obj%timer%stop("blas")
......@@ -1519,16 +1430,9 @@
if ( myend > lre ) myend = lre
if ( myend-mystart+1 < 1) cycle
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('N', 'T', myend-mystart+1, lce-lcs+1, 2*n_cols, -ONE, &
call PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, myend-mystart+1, lce-lcs+1, 2*n_cols, -ONE, &
vmrCPU(mystart, 1), ubound(vmrCPU,1), umcCPU(lcs,1), ubound(umcCPU,1), &
ONE, a(mystart,lcs), ubound(a,1))
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'C', myend-mystart+1, lce-lcs+1, 2*n_cols, -ONE, &
vmrCPU(mystart, 1), ubound(vmrCPU,1), umcCPU(lcs,1), ubound(umcCPU,1), &
one, a(mystart,lcs), ubound(a,1))
#endif
call obj%timer%stop("blas")
enddo
!$omp end parallel
......@@ -1557,12 +1461,7 @@
if (useGPU) then
call obj%timer%start("cublas")
#if REALCASE == 1
call cublas_PRECISION_GEMM('N', 'T', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMM('N', 'C', &
#endif
call cublas_PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, &
lre, lce-lcs+1, 2*n_cols, -ONE, &
vmr_dev, cur_l_rows, (umc_dev +(lcs-1)* &
size_of_datatype), &
......@@ -1573,13 +1472,7 @@
else ! useGPU
call obj%timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('N', 'T', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'C', &
#endif
lre,lce-lcs+1, 2*n_cols, -ONE, &
call PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, lre,lce-lcs+1, 2*n_cols, -ONE, &
vmrCPU, ubound(vmrCPU,dim=1), umcCPU(lcs,1), ubound(umcCPU,dim=1), &
ONE, a(1,lcs), lda)
call obj%timer%stop("blas")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment