Commit 8c16184c authored by Pavel Kus's avatar Pavel Kus

blas kernel runs partially on GPU

added some more wrappers for the cublas functions
parent 69965676
......@@ -387,6 +387,19 @@ extern "C" {
m, n, k, &alpha_casted, A_casted, lda, B_casted, ldb, &beta_casted, C_casted, ldc);
}
// TODO so far only double real
void cublasDsyrk_elpa_wrapper (intptr_t handle, char uplo, char trans, int n, int k,
double alpha, const double *A, int lda,
double beta, double *C, int ldc) {
cublasDsyrk(*((cublasHandle_t*)handle), fill_mode_new_api(uplo), operation_new_api(trans),
n, k, &alpha, A, lda, &beta, C, ldc);
}
// TODO so far only double real
void cublasDscal_elpa_wrapper (intptr_t handle, int n, double alpha, double *x, int incx) {
cublasDscal(*((cublasHandle_t*)handle), n, &alpha, x, incx);
}
// todo: new CUBLAS API diverged from standard BLAS api for these functions
// todo: it provides out-of-place (and apparently more efficient) implementation
......
......@@ -328,6 +328,40 @@ module cuda_functions
end subroutine cublas_strmm_c
end interface
!TODO so far only double real
interface
subroutine cublas_dsyrk_c(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc) &
bind(C,name='cublasDsyrk_elpa_wrapper')
use iso_c_binding
implicit none
character(1, C_CHAR), value :: uplo, trans
integer(kind=C_INT), value :: n, k
integer(kind=C_INT), intent(in), value :: lda, ldc
real(kind=C_DOUBLE), value :: alpha, beta
integer(kind=C_intptr_T), value :: a, c
integer(kind=C_intptr_T), value :: handle
end subroutine cublas_dsyrk_c
end interface
!TODO so far only double real
interface
subroutine cublas_dscal_c(handle, n, alpha, x, incx) &
bind(C,name='cublasDscal_elpa_wrapper')
use iso_c_binding
implicit none
integer(kind=C_INT), value :: n
integer(kind=C_INT), intent(in), value :: incx
real(kind=C_DOUBLE), value :: alpha
integer(kind=C_intptr_T), value :: x
integer(kind=C_intptr_T), value :: handle
end subroutine cublas_dscal_c
end interface
interface
subroutine cublas_zgemm_c(handle, cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) &
bind(C,name='cublasZgemm_elpa_wrapper')
......@@ -759,6 +793,38 @@ module cuda_functions
#endif
end subroutine cublas_strmm
!TODO so far only double real
subroutine cublas_dsyrk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
use iso_c_binding
implicit none
character(1, C_CHAR), value :: uplo, trans
integer(kind=C_INT) :: n, k
integer(kind=C_INT), intent(in) :: lda, ldc
real(kind=C_DOUBLE) :: alpha, beta
integer(kind=C_intptr_T) :: a, c
#ifdef WITH_GPU_VERSION
call cublas_dsyrk_c(cublasHandle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
#endif
end subroutine cublas_dsyrk
!TODO so far only double real
subroutine cublas_dscal(n, alpha, x, incx)
use iso_c_binding
implicit none
integer(kind=C_INT) :: n
integer(kind=C_INT), intent(in) :: incx
real(kind=C_DOUBLE) :: alpha
integer(kind=C_intptr_T) :: x
#ifdef WITH_GPU_VERSION
call cublas_dscal_c(cublasHandle, n, alpha, x, incx)
#endif
end subroutine cublas_dscal
subroutine cublas_zgemm(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc)
use iso_c_binding
......
......@@ -1502,14 +1502,14 @@
&MATH_DATATYPE&
&_blas_4hv_&
&PRECISION&
& (useGPU, a(1,j+off+a_off-3,istripe,my_thread), w, nbw, nl, stripe_width, nbw)
& (useGPU, a(1,j+off+a_off-3,istripe,my_thread), w, nbw, nl, stripe_width, nbw, h_dev, s_dev, q_dev, w_dev)
#else
call quad_hh_trafo_&
&MATH_DATATYPE&
&_blas_4hv_&
&PRECISION&
& (useGPU, a(1:stripe_width,j+off+a_off-3:j+off+a_off+nbw-1,istripe,my_thread), w(1:nbw,1:6), nbw, nl, &
stripe_width, nbw)
stripe_width, nbw, h_dev, s_dev, q_dev, w_dev)
#endif
#else
......@@ -1519,14 +1519,14 @@
&MATH_DATATYPE&
&_blas_4hv_&
&PRECISION&
& (useGPU, a(1,j+off+a_off-3,istripe), w, nbw, nl, stripe_width, nbw)
& (useGPU, a(1,j+off+a_off-3,istripe), w, nbw, nl, stripe_width, nbw, h_dev, s_dev, q_dev, w_dev)
#else
call quad_hh_trafo_&
&MATH_DATATYPE&
&_blas_4hv_&
&PRECISION&
& (useGPU, a(1:stripe_width,j+off+a_off-3:j+off+a_off+nbw-1,istripe), w(1:nbw,1:6), nbw, nl, &
stripe_width, nbw)
stripe_width, nbw, h_dev, s_dev, q_dev, w_dev)
#endif
#endif
......
......@@ -1534,7 +1534,7 @@
&MATH_DATATYPE&
&_openmp_&
&PRECISION &
(obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
(obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
l_nev, a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1551,7 +1551,7 @@
&MATH_DATATYPE&
&_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1639,7 +1639,7 @@
&MATH_DATATYPE&
&_openmp_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, l_nev, a_off, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, l_nev, a_off, &
nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1682,7 +1682,7 @@
&MATH_DATATYPE&
&_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1753,7 +1753,7 @@
&MATH_DATATYPE&
&_openmp_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width ,a_dim2, stripe_count, max_threads, l_nev, a_off, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width ,a_dim2, stripe_count, max_threads, l_nev, a_off, &
nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1771,7 +1771,7 @@
&MATH_DATATYPE&
&_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1833,7 +1833,7 @@
&MATH_DATATYPE&
&_openmp_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, l_nev, a_off, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, l_nev, a_off, &
nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......@@ -1850,7 +1850,7 @@
&MATH_DATATYPE&
&_&
&PRECISION&
& (obj, useGPU_LEGACY, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
& (obj, useGPU, wantDebug, aIntern, aIntern_dev, stripe_width, a_dim2, stripe_count, max_threads, &
a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
#if REALCASE == 1
hh_dot_dev, &
......
......@@ -64,13 +64,15 @@
&MATH_DATATYPE&
&_blas_4hv_&
&PRECISION&
& (useGPU, q, hh, nb, nq, ldq, ldh)
& (useGPU, q, hh, nb, nq, ldq, ldh, h_dev, s_dev, q_dev, w_dev)
use precision
use iso_c_binding
use cuda_functions
implicit none
#include "../../general/precision_kinds.F90"
logical, intent(in) :: useGPU
logical :: useGPU
integer(kind=ik), intent(in) :: nb, nq, ldq, ldh
#ifdef USE_ASSUMED_SIZE
......@@ -85,7 +87,21 @@
real(kind=C_DATATYPE_KIND) :: h_mat(4, nb+3)
real(kind=C_DATATYPE_KIND) :: s_mat(4, 4)
integer(kind=ik) :: i, j, k
!TODO remove
!real(kind=C_DATATYPE_KIND) :: q_extra(1:ldq,1:nb+3)
integer(kind=c_intptr_t) :: h_dev, s_dev, q_dev, w_dev
logical :: successCUDA
integer(kind=c_intptr_t), parameter :: size_of_datatype = size_of_&
&PRECISION&
&_&
&MATH_DATATYPE
integer(kind=ik) :: i, j, k
integer(kind=ik), parameter :: max_block_blas = 4
! Calculate dot product of the two Householder vectors
......@@ -102,36 +118,138 @@
h_mat(3,3:nb+1) = -hh(2:nb, 3)
h_mat(4,2:nb) = -hh(2:nb, 4)
if(useGPU) then
! nb == nbw
successCUDA = cuda_memcpy(h_dev, loc(h_mat(1,1)), &
max_block_blas * (nb+3) * size_of_datatype, &
cudaMemcpyHostToDevice)
if (.not.(successCUDA)) then
print *,"blas_block4_kernel: error in cudaMemcpy, h_dev host to device"
stop 1
endif
! nq == stripe_width
successCUDA = cuda_memcpy(q_dev, loc(q(1,1)), &
ldq * (nb+3) * size_of_datatype, &
cudaMemcpyHostToDevice)
if (.not.(successCUDA)) then
print *,"blas_block4_kernel: error in cudaMemcpy, q_dev host to device"
stop 1
endif
endif
! TODO we do not need the diagonal, but how to do it with BLAS?
!s_mat = - matmul(h_mat, transpose(h_mat))
call PRECISION_SYRK('L', 'N', 4, nb+3, &
-ONE, h_mat, 4, &
ZERO, s_mat, 4)
if(useGPU) then
call cublas_PRECISION_SYRK('L', 'N', 4, nb+3, &
-ONE, h_dev, 4, &
ZERO, s_dev, 4)
else
call PRECISION_SYRK('L', 'N', 4, nb+3, &
-ONE, h_mat, 4, &
ZERO, s_mat, 4)
endif
!w_comb = - matmul(q(1:nq, 1:nb+3), transpose(h_mat))
call PRECISION_GEMM('N', 'T', nq, 4, nb+3, &
-ONE, q, ldq, &
h_mat, 4, &
ZERO, w_comb, ldq)
if(useGPU) then
call cublas_PRECISION_GEMM('N', 'T', nq, 4, nb+3, &
-ONE, q_dev, ldq, &
h_dev, 4, &
ZERO, w_dev, ldq)
else
call PRECISION_GEMM('N', 'T', nq, 4, nb+3, &
-ONE, q, ldq, &
h_mat, 4, &
ZERO, w_comb, ldq)
endif
! Rank-1 update
!w_comb(1:nq,1) = hh(1,1) * w_comb(1:nq, 1)
call PRECISION_SCAL(nq, hh(1,1), w_comb(1:nq, 1), 1)
if(useGPU) then
call cublas_PRECISION_SCAL(nq, hh(1,1), w_dev, 1)
else
call PRECISION_SCAL(nq, hh(1,1), w_comb(1, 1), 1)
endif
do i = 2, 4
!w_comb(1:nq,i) = matmul(w_comb(1:nq,1:i-1), hh(1,i) * s_mat(i,1:i-1)) + hh(1,i) * w_comb(1:nq, i)
call PRECISION_GEMV('N', nq, i-1, &
hh(1,i), w_comb(1, 1), ldq, &
s_mat(i,1), 4, &
hh(1,i), w_comb(1,i), 1)
! w_comb(1:nq,i) = matmul(w_comb(1:nq,1:i-1), hh(1,i) * s_mat(i,1:i-1)) + hh(1,i) * w_comb(1:nq, i)
if(useGPU) then
call cublas_PRECISION_GEMV('N', nq, i-1, &
hh(1,i), w_dev, ldq, &
s_dev + (i - 1) * size_of_datatype, 4, &
hh(1,i), w_dev + (i-1) * ldq * size_of_datatype, 1)
else
call PRECISION_GEMV('N', nq, i-1, &
hh(1,i), w_comb(1, 1), ldq, &
s_mat(i,1), 4, &
hh(1,i), w_comb(1,i), 1)
endif
enddo
! ---------------------
if(useGPU) then
! successCUDA = cuda_memcpy(loc(s_mat(1,1)), s_dev, &
! 4 * 4 * size_of_datatype, &
! cudaMemcpyDeviceToHost)
! if (.not.(successCUDA)) then
! print *,"blas_block4_kernel: error in cudaMemcpy, q_dev device to host"
! stop 1
! endif
successCUDA = cuda_memcpy(loc(w_comb(1,1)), w_dev, &
nq * 4 * size_of_datatype, &
cudaMemcpyDeviceToHost)
if (.not.(successCUDA)) then
print *,"blas_block4_kernel: error in cudaMemcpy, w_dev device to host"
stop 1
endif
successCUDA = cuda_memcpy(loc(h_mat(1,1)), h_dev, &
max_block_blas * (nb+3) * size_of_datatype, &
cudaMemcpyDeviceToHost)
if (.not.(successCUDA)) then
print *,"blas_block4_kernel: error in cudaMemcpy, w_dev device to host"
stop 1
endif
endif
useGPU = .false.
! ---------------------
!q(1:nq, 1:nb+3) = matmul(w_comb, h_mat) + q(1:nq, 1:nb+3)
call PRECISION_GEMM('N', 'N', nq, nb+3, 4, &
ONE, w_comb, ldq, &
h_mat, 4, &
ONE, q, ldq)
if(useGPU) then
call cublas_PRECISION_GEMM('N', 'N', nq, nb+3, 4, &
ONE, w_dev, ldq, &
h_dev, 4, &
ONE, q_dev, ldq)
else
call PRECISION_GEMM('N', 'N', nq, nb+3, 4, &
ONE, w_comb, ldq, &
h_mat, 4, &
ONE, q, ldq)
endif
if(useGPU) then
!successCUDA = cuda_memcpy(loc(q_extra(1,1)), q_dev, &
successCUDA = cuda_memcpy(loc(q(1,1)), q_dev, &
ldq * (nb+3) * size_of_datatype, &
cudaMemcpyDeviceToHost)
if (.not.(successCUDA)) then
print *,"blas_block4_kernel: error in cudaMemcpy, q_dev device to host"
stop 1
endif
endif
! print *, "difference ", norm2(q(1:ldq,1:nb+3)-q_extra(1:ldq,1:nb+3)), ", ldq ", ldq, ", nq ", nq, ", nb ", nb
! print *, q(1:ldq,1:nb+3)
! stop 1
end subroutine
......
......@@ -46,6 +46,8 @@
#undef cublas_PRECISION_TRMM
#undef cublas_PRECISION_GEMV
#undef cublas_PRECISION_SYMV
#undef cublas_PRECISION_SYRK
#undef cublas_PRECISION_SCAL
#undef scal_PRECISION_GEMM
#undef scal_PRECISION_NRM2
#undef scal_PRECISION_LASET
......@@ -106,6 +108,8 @@
#define cublas_PRECISION_TRMM cublas_DTRMM
#define cublas_PRECISION_GEMV cublas_DGEMV
#define cublas_PRECISION_SYMV cublas_DSYMV
#define cublas_PRECISION_SYRK cublas_DSYRK
#define cublas_PRECISION_SCAL cublas_DSCAL
#define scal_PRECISION_GEMM PDGEMM
#define scal_PRECISION_NRM2 PDNRM2
#define scal_PRECISION_LASET PDLASET
......@@ -167,6 +171,8 @@
#define cublas_PRECISION_TRMM cublas_STRMM
#define cublas_PRECISION_GEMV cublas_SGEMV
#define cublas_PRECISION_SYMV cublas_SSYMV
#define cublas_PRECISION_SYRK cublas_SSYRK
#define cublas_PRECISION_SCAL cublas_SSCAL
#define scal_PRECISION_GEMM PSGEMM
#define scal_PRECISION_NRM2 PSNRM2
#define scal_PRECISION_LASET PSLASET
......@@ -236,6 +242,8 @@
#undef cublas_PRECISION_TRMM
#undef cublas_PRECISION_GEMV
#undef cublas_PRECISION_SYMV
#undef cublas_PRECISION_SYRK
#undef cublas_PRECISION_SCAL
#undef scal_PRECISION_GEMM
#undef scal_PRECISION_DOTC
#undef scal_PRECISION_LASET
......@@ -307,6 +315,8 @@
#define cublas_PRECISION_TRMM cublas_ZTRMM
#define cublas_PRECISION_GEMV cublas_ZGEMV
#define cublas_PRECISION_SYMV cublas_ZSYMV
#define cublas_PRECISION_SYRK cublas_ZSYRK
#define cublas_PRECISION_SCAL cublas_ZSCAL
#define scal_PRECISION_GEMM PZGEMM
#define scal_PRECISION_DOTC PZDOTC
#define scal_PRECISION_LASET PZLASET
......@@ -372,6 +382,8 @@
#define cublas_PRECISION_TRMM cublas_CTRMM
#define cublas_PRECISION_GEMV cublas_CGEMV
#define cublas_PRECISION_SYMV cublas_CSYMV
#define cublas_PRECISION_SYRK cublas_CSYRK
#define cublas_PRECISION_SCAL cublas_CSCAL
#define scal_PRECISION_GEMM PCGEMM
#define scal_PRECISION_DOTC PCDOTC
#define scal_PRECISION_LASET PCLASET
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment