Commit 4b953ea9 authored by Pavel Kus
Browse files

BLAS_TRANS_OR_CONJ and MPI_MATH_DATATYPE_PRECISION introduced

parent f46f93d8
...@@ -413,14 +413,8 @@ ...@@ -413,14 +413,8 @@
#ifdef WITH_MPI #ifdef WITH_MPI
call timer%start("mpi_communication") call timer%start("mpi_communication")
call mpi_allreduce(aux1, aux2, 2, & call mpi_allreduce(aux1, aux2, 2, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1 MPI_SUM, mpi_comm_rows, mpierr)
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, mpi_comm_rows, mpierr)
call timer%stop("mpi_communication") call timer%stop("mpi_communication")
#else /* WITH_MPI */ #else /* WITH_MPI */
aux2 = aux1 aux2 = aux1
...@@ -456,13 +450,7 @@ ...@@ -456,13 +450,7 @@
#ifdef WITH_MPI #ifdef WITH_MPI
! Broadcast the Householder vector (and tau) along columns ! Broadcast the Householder vector (and tau) along columns
call MPI_Bcast(v_row, l_rows+1, & call MPI_Bcast(v_row, l_rows+1, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr) pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr)
#endif /* WITH_MPI */ #endif /* WITH_MPI */
...@@ -558,13 +546,7 @@ ...@@ -558,13 +546,7 @@
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
if (mod(n_iter,n_threads) == my_thread) then if (mod(n_iter,n_threads) == my_thread) then
call timer%start("blas") call timer%start("blas")
#if REALCASE == 1 call PRECISION_GEMV(BLAS_TRANS_OR_CONJ, l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, &
call PRECISION_GEMV('T', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMV('C', &
#endif
l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, &
ONE, a_mat(l_row_beg,l_col_beg), lda, & ONE, a_mat(l_row_beg,l_col_beg), lda, &
v_row(l_row_beg), 1, ONE, uc_p(l_col_beg,my_thread), 1) v_row(l_row_beg), 1, ONE, uc_p(l_col_beg,my_thread), 1)
...@@ -581,13 +563,8 @@ ...@@ -581,13 +563,8 @@
if (useGPU) then if (useGPU) then
a_offset = ((l_row_beg-1) + (l_col_beg - 1) * lda) * size_of_datatype a_offset = ((l_row_beg-1) + (l_col_beg - 1) * lda) * size_of_datatype
call timer%start("cublas") call timer%start("cublas")
#if REALCASE == 1 call cublas_PRECISION_GEMV(BLAS_TRANS_OR_CONJ, &
call cublas_PRECISION_GEMV('T', &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMV('C', &
#endif
l_row_end-l_row_beg+1,l_col_end-l_col_beg+1, & l_row_end-l_row_beg+1,l_col_end-l_col_beg+1, &
ONE, a_dev + a_offset, lda, & ONE, a_dev + a_offset, lda, &
v_row_dev + (l_row_beg - 1) * & v_row_dev + (l_row_beg - 1) * &
...@@ -602,21 +579,15 @@ ...@@ -602,21 +579,15 @@
size_of_datatype, 1, & size_of_datatype, 1, &
ONE, u_row_dev + (l_row_beg - 1) * & ONE, u_row_dev + (l_row_beg - 1) * &
size_of_datatype, 1) size_of_datatype, 1)
endif endif
call timer%stop("cublas") call timer%stop("cublas")
else ! useGPU else ! useGPU
call timer%start("blas") call timer%start("blas")
#if REALCASE == 1 call PRECISION_GEMV(BLAS_TRANS_OR_CONJ, &
call PRECISION_GEMV('T', & l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, &
#endif ONE, a_mat(l_row_beg, l_col_beg), lda, &
#if COMPLEXCASE == 1 v_row(l_row_beg), 1, &
call PRECISION_GEMV('C', & ONE, u_col(l_col_beg), 1)
#endif
l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, &
ONE, a_mat(l_row_beg, l_col_beg), lda, &
v_row(l_row_beg), 1, &
ONE, u_col(l_col_beg), 1)
if (i/=j) then if (i/=j) then
...@@ -699,13 +670,7 @@ ...@@ -699,13 +670,7 @@
tmp(1:l_cols) = u_col(1:l_cols) tmp(1:l_cols) = u_col(1:l_cols)
#ifdef WITH_MPI #ifdef WITH_MPI
call timer%start("mpi_communication") call timer%start("mpi_communication")
call mpi_allreduce(tmp, u_col, l_cols, & call mpi_allreduce(tmp, u_col, l_cols, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, mpi_comm_rows, mpierr) MPI_SUM, mpi_comm_rows, mpierr)
call timer%stop("mpi_communication") call timer%stop("mpi_communication")
#else /* WITH_MPI */ #else /* WITH_MPI */
...@@ -735,10 +700,10 @@ ...@@ -735,10 +700,10 @@
#ifdef WITH_MPI #ifdef WITH_MPI
call timer%start("mpi_communication") call timer%start("mpi_communication")
#if REALCASE == 1 #if REALCASE == 1
call mpi_allreduce(x, vav, 1, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_cols, mpierr) call mpi_allreduce(x, vav, 1, MPI_MATH_DATATYPE_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#endif #endif
#if COMPLEXCASE == 1 #if COMPLEXCASE == 1
call mpi_allreduce(xc, vav, 1 , MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_cols, mpierr) call mpi_allreduce(xc, vav, 1 , MPI_MATH_DATATYPE_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#endif #endif
call timer%stop("mpi_communication") call timer%stop("mpi_communication")
#else /* WITH_MPI */ #else /* WITH_MPI */
...@@ -803,34 +768,24 @@ ...@@ -803,34 +768,24 @@
cycle cycle
if (useGPU) then if (useGPU) then
call timer%start("cublas") call timer%start("cublas")
#if REALCASE == 1 call cublas_PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, &
call cublas_PRECISION_GEMM('N', 'T', & l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, 2*n_stored_vecs, &
#endif
#if COMPLEXCASE == 1
call cublas_PRECISION_GEMM('N', 'C', &
#endif
l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, 2*n_stored_vecs, &
ONE, vu_stored_rows_dev + (l_row_beg - 1) * & ONE, vu_stored_rows_dev + (l_row_beg - 1) * &
size_of_datatype, & size_of_datatype, &
max_local_rows, uv_stored_cols_dev + (l_col_beg - 1) * & max_local_rows, uv_stored_cols_dev + (l_col_beg - 1) * &
size_of_datatype, & size_of_datatype, &
max_local_cols, ONE, a_dev + ((l_row_beg - 1) + (l_col_beg - 1) * lda) * & max_local_cols, ONE, a_dev + ((l_row_beg - 1) + (l_col_beg - 1) * lda) * &
size_of_datatype , lda) size_of_datatype , lda)
call timer%stop("cublas") call timer%stop("cublas")
else !useGPU else !useGPU
call timer%start("blas") call timer%start("blas")
#if REALCASE == 1 call PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, &
call PRECISION_GEMM('N', 'T', & l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, 2*n_stored_vecs, &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'C', &
#endif
l_row_end-l_row_beg+1, l_col_end-l_col_beg+1, 2*n_stored_vecs, &
ONE, vu_stored_rows(l_row_beg,1), ubound(vu_stored_rows,dim=1), & ONE, vu_stored_rows(l_row_beg,1), ubound(vu_stored_rows,dim=1), &
uv_stored_cols(l_col_beg,1), ubound(uv_stored_cols,dim=1), & uv_stored_cols(l_col_beg,1), ubound(uv_stored_cols,dim=1), &
ONE, a_mat(l_row_beg,l_col_beg), lda) ONE, a_mat(l_row_beg,l_col_beg), lda)
call timer%stop("blas") call timer%stop("blas")
endif !useGPU endif !useGPU
enddo enddo
...@@ -856,7 +811,7 @@ ...@@ -856,7 +811,7 @@
if (useGPU) then if (useGPU) then
!a_dev(l_rows,l_cols) = a_mat(l_rows,l_cols) !a_dev(l_rows,l_cols) = a_mat(l_rows,l_cols)
!successCUDA = cuda_threadsynchronize() !successCUDA = cuda_threadsynchronize()
!check_memcpy_cuda("tridiag: a_dev 4a5a", successCUDA) !check_memcpy_cuda("tridiag: a_dev 4a5a", successCUDA)
successCUDA = cuda_memcpy(a_dev + a_offset, int(loc(a_mat(l_rows, l_cols)),kind=c_size_t), & successCUDA = cuda_memcpy(a_dev + a_offset, int(loc(a_mat(l_rows, l_cols)),kind=c_size_t), &
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#undef DOUBLE_PRECISION_REAL #undef DOUBLE_PRECISION_REAL
#undef MATH_DATATYPE #undef MATH_DATATYPE
#undef BLAS_TRANS_OR_CONJ
#undef PRECISION #undef PRECISION
#undef PRECISION_STR #undef PRECISION_STR
#undef REAL_DATATYPE #undef REAL_DATATYPE
...@@ -41,10 +42,12 @@ ...@@ -41,10 +42,12 @@
#undef CONST_2_0 #undef CONST_2_0
#undef CONST_8_0 #undef CONST_8_0
#undef MPI_REAL_PRECISION #undef MPI_REAL_PRECISION
#undef MPI_MATH_DATATYPE_PRECISION
#undef C_DATATYPE_KIND #undef C_DATATYPE_KIND
/* General definitions needed in single and double case */ /* General definitions needed in single and double case */
#define MATH_DATATYPE real #define MATH_DATATYPE real
#define BLAS_TRANS_OR_CONJ 'T'
#ifdef DOUBLE_PRECISION #ifdef DOUBLE_PRECISION
#define DOUBLE_PRECISION_REAL #define DOUBLE_PRECISION_REAL
...@@ -87,6 +90,7 @@ ...@@ -87,6 +90,7 @@
#define CONST_2_0 2.0_rk8 #define CONST_2_0 2.0_rk8
#define CONST_8_0 8.0_rk8 #define CONST_8_0 8.0_rk8
#define MPI_REAL_PRECISION MPI_REAL8 #define MPI_REAL_PRECISION MPI_REAL8
#define MPI_MATH_DATATYPE_PRECISION MPI_REAL8
#define C_DATATYPE_KIND c_double #define C_DATATYPE_KIND c_double
#endif /* DOUBLE_PRECISION */ #endif /* DOUBLE_PRECISION */
...@@ -131,6 +135,7 @@ ...@@ -131,6 +135,7 @@
#define CONST_2_0 2.0_rk4 #define CONST_2_0 2.0_rk4
#define CONST_8_0 8.0_rk4 #define CONST_8_0 8.0_rk4
#define MPI_REAL_PRECISION MPI_REAL4 #define MPI_REAL_PRECISION MPI_REAL4
#define MPI_MATH_DATATYPE_PRECISION MPI_REAL4
#define C_DATATYPE_KIND c_float #define C_DATATYPE_KIND c_float
#endif /* SINGLE_PRECISION */ #endif /* SINGLE_PRECISION */
...@@ -141,6 +146,7 @@ ...@@ -141,6 +146,7 @@
#undef DOUBLE_PRECISION_COMPLEX #undef DOUBLE_PRECISION_COMPLEX
#undef MATH_DATATYPE #undef MATH_DATATYPE
#undef BLAS_TRANS_OR_CONJ
#undef PRECISION #undef PRECISION
#undef COMPLEX_DATATYPE #undef COMPLEX_DATATYPE
/* in the complex case, real-valued variables are sometimes needed as well */ /* in the complex case, real-valued variables are sometimes needed as well */
...@@ -176,6 +182,7 @@ ...@@ -176,6 +182,7 @@
#undef cublas_PRECISION_SYMV #undef cublas_PRECISION_SYMV
#undef PRECISION_SUFFIX #undef PRECISION_SUFFIX
#undef MPI_COMPLEX_PRECISION #undef MPI_COMPLEX_PRECISION
#undef MPI_MATH_DATATYPE_PRECISION
#undef MPI_COMPLEX_EXPLICIT_PRECISION #undef MPI_COMPLEX_EXPLICIT_PRECISION
#undef MPI_REAL_PRECISION #undef MPI_REAL_PRECISION
#undef KIND_PRECISION #undef KIND_PRECISION
...@@ -195,6 +202,7 @@ ...@@ -195,6 +202,7 @@
/* General definitions needed in single and double case */ /* General definitions needed in single and double case */
#define MATH_DATATYPE complex #define MATH_DATATYPE complex
#define BLAS_TRANS_OR_CONJ 'C'
#ifdef DOUBLE_PRECISION #ifdef DOUBLE_PRECISION
...@@ -233,6 +241,7 @@ ...@@ -233,6 +241,7 @@
#define cublas_PRECISION_GEMV cublas_ZGEMV #define cublas_PRECISION_GEMV cublas_ZGEMV
#define cublas_PRECISION_SYMV cublas_ZSYMV #define cublas_PRECISION_SYMV cublas_ZSYMV
#define MPI_COMPLEX_PRECISION MPI_DOUBLE_COMPLEX #define MPI_COMPLEX_PRECISION MPI_DOUBLE_COMPLEX
#define MPI_MATH_DATATYPE_PRECISION MPI_DOUBLE_COMPLEX
#define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX16 #define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX16
#define MPI_REAL_PRECISION MPI_REAL8 #define MPI_REAL_PRECISION MPI_REAL8
#define KIND_PRECISION rk8 #define KIND_PRECISION rk8
...@@ -287,6 +296,7 @@ ...@@ -287,6 +296,7 @@
#define cublas_PRECISION_GEMV cublas_CGEMV #define cublas_PRECISION_GEMV cublas_CGEMV
#define cublas_PRECISION_SYMV cublas_CSYMV #define cublas_PRECISION_SYMV cublas_CSYMV
#define MPI_COMPLEX_PRECISION MPI_COMPLEX #define MPI_COMPLEX_PRECISION MPI_COMPLEX
#define MPI_MATH_DATATYPE_PRECISION MPI_COMPLEX
#define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX8 #define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX8
#define MPI_REAL_PRECISION MPI_REAL4 #define MPI_REAL_PRECISION MPI_REAL4
#define KIND_PRECISION rk4 #define KIND_PRECISION rk4
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment