Commit a4d30cfb authored by Pavel Kus
Browse files

more single/double real/complex

parent 6f4f00e5
...@@ -49,22 +49,13 @@ ...@@ -49,22 +49,13 @@
use precision use precision
use elpa_abstract_impl use elpa_abstract_impl
implicit none implicit none
#include "../general/precision_kinds.F90"
class(elpa_abstract_impl_t), intent(inout) :: obj class(elpa_abstract_impl_t), intent(inout) :: obj
integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(obj%local_nrows,*)
#else
real(kind=REAL_DATATYPE) :: a(obj%local_nrows,obj%local_ncols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE #ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(obj%local_nrows,*) MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,*)
#else #else
complex(kind=COMPLEX_DATATYPE) :: a(obj%local_nrows,obj%local_ncols) MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,obj%local_ncols)
#endif
#endif #endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx
...@@ -72,12 +63,7 @@ ...@@ -72,12 +63,7 @@
integer(kind=ik) :: lcs, lce, lrs, lre integer(kind=ik) :: lcs, lce, lrs, lre
integer(kind=ik) :: tile_size, l_rows_tile, l_cols_tile integer(kind=ik) :: tile_size, l_rows_tile, l_cols_tile
#if REALCASE == 1 MATH_DATATYPE(kind=rck), allocatable :: tmp1(:), tmp2(:,:), tmatr(:,:), tmatc(:,:)
real(kind=REAL_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmatr(:,:), tmatc(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmatr(:,:), tmatc(:,:)
#endif
logical :: wantDebug logical :: wantDebug
logical :: success logical :: success
integer(kind=ik) :: istat, debug integer(kind=ik) :: istat, debug
...@@ -256,16 +242,8 @@ ...@@ -256,16 +242,8 @@
call obj%timer%start("blas") call obj%timer%start("blas")
if (l_cols-l_colx+1>0) & if (l_cols-l_colx+1>0) &
#if REALCASE == 1 call PRECISION_TRSM('L', 'U', BLAS_TRANS_OR_CONJ, 'N', nblk, l_cols-l_colx+1, ONE, tmp2, &
call PRECISION_TRSM('L', 'U', 'T', 'N', nblk, l_cols-l_colx+1, CONST_1_0, tmp2, ubound(tmp2,dim=1), & ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
a(l_row1,l_colx), lda)
#endif
#if COMPLEXCASE == 1
call PRECISION_TRSM('L', 'U', 'C', 'N', nblk, l_cols-l_colx+1, CONST_COMPLEX_PAIR_1_0, &
tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#endif
call obj%timer%stop("blas") call obj%timer%stop("blas")
endif endif
...@@ -282,13 +260,7 @@ ...@@ -282,13 +260,7 @@
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
if (l_cols-l_colx+1>0) & if (l_cols-l_colx+1>0) &
call MPI_Bcast(tmatc(l_colx,i), l_cols-l_colx+1, & call MPI_Bcast(tmatc(l_colx,i), l_cols-l_colx+1, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
prow(n, nblk, np_rows), mpi_comm_rows, mpierr) prow(n, nblk, np_rows), mpi_comm_rows, mpierr)
call obj%timer%stop("mpi_communication") call obj%timer%stop("mpi_communication")
...@@ -310,18 +282,9 @@ ...@@ -310,18 +282,9 @@
lre = min(l_rows,(i+1)*l_rows_tile) lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<lrs) cycle if (lce<lcs .or. lre<lrs) cycle
call obj%timer%start("blas") call obj%timer%start("blas")
call PRECISION_GEMM('N', BLAS_TRANS_OR_CONJ, lre-lrs+1, lce-lcs+1, nblk, -ONE, &
#if REALCASE == 1
call PRECISION_GEMM('N', 'T', lre-lrs+1, lce-lcs+1, nblk, -CONST_1_0, &
tmatr(lrs,1), ubound(tmatr,dim=1), tmatc(lcs,1), ubound(tmatc,dim=1), &
CONST_1_0, a(lrs,lcs), lda)
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'C', lre-lrs+1, lce-lcs+1, nblk, -CONST_COMPLEX_PAIR_1_0, &
tmatr(lrs,1), ubound(tmatr,dim=1), tmatc(lcs,1), ubound(tmatc,dim=1), & tmatr(lrs,1), ubound(tmatr,dim=1), tmatc(lcs,1), ubound(tmatc,dim=1), &
CONST_COMPLEX_PAIR_1_0, a(lrs,lcs), lda) ONE, a(lrs,lcs), lda)
#endif
call obj%timer%stop("blas") call obj%timer%stop("blas")
enddo enddo
...@@ -351,8 +314,3 @@ ...@@ -351,8 +314,3 @@
&_& &_&
&PRECISION& &PRECISION&
&") &")
#undef REALCASE
#undef COMPLEXCASE
#undef DOUBLE_PRECISION
#undef SINGLE_PRECISION
...@@ -58,34 +58,19 @@ ...@@ -58,34 +58,19 @@
use elpa_mpi use elpa_mpi
use elpa_abstract_impl use elpa_abstract_impl
implicit none implicit none
#include "../general/precision_kinds.F90"
class(elpa_abstract_impl_t), intent(inout) :: obj class(elpa_abstract_impl_t), intent(inout) :: obj
integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(obj%local_nrows,*)
#else
real(kind=REAL_DATATYPE) :: a(obj%local_nrows,obj%local_ncols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE #ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(obj%local_nrows,*) MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,*)
#else #else
complex(kind=COMPLEX_DATATYPE) :: a(obj%local_nrows,obj%local_ncols) MATH_DATATYPE(kind=rck) :: a(obj%local_nrows,obj%local_ncols)
#endif
#endif #endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx
integer(kind=ik) :: n, nc, i, info, ns, nb integer(kind=ik) :: n, nc, i, info, ns, nb
#if REALCASE == 1 MATH_DATATYPE(kind=rck), allocatable :: tmp1(:), tmp2(:,:), tmat1(:,:), tmat2(:,:)
real(kind=REAL_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmat1(:,:), tmat2(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmat1(:,:), tmat2(:,:)
#endif
logical :: wantDebug logical :: wantDebug
logical :: success logical :: success
integer(kind=ik) :: istat, debug integer(kind=ik) :: istat, debug
...@@ -211,13 +196,7 @@ ...@@ -211,13 +196,7 @@
endif endif
#ifdef WITH_MPI #ifdef WITH_MPI
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
call MPI_Bcast(tmp1, nb*(nb+1)/2, & call MPI_Bcast(tmp1, nb*(nb+1)/2, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(n, nblk, np_cols), mpi_comm_cols, mpierr) pcol(n, nblk, np_cols), mpi_comm_cols, mpierr)
call obj%timer%stop("mpi_communication") call obj%timer%stop("mpi_communication")
#endif /* WITH_MPI */ #endif /* WITH_MPI */
...@@ -229,13 +208,7 @@ ...@@ -229,13 +208,7 @@
call obj%timer%start("blas") call obj%timer%start("blas")
if (l_cols-l_colx+1>0) & if (l_cols-l_colx+1>0) &
call PRECISION_TRMM ('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, & call PRECISION_TRMM('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, ONE, &
#if REALCASE == 1
CONST_1_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_1_0, &
#endif
tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda) tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
call obj%timer%stop("blas") call obj%timer%stop("blas")
if (l_colx<=l_cols) tmat2(1:nb,l_colx:l_cols) = a(l_row1:l_row1+nb-1,l_colx:l_cols) if (l_colx<=l_cols) tmat2(1:nb,l_colx:l_cols) = a(l_row1:l_row1+nb-1,l_colx:l_cols)
...@@ -252,13 +225,7 @@ ...@@ -252,13 +225,7 @@
do i=1,nb do i=1,nb
#ifdef WITH_MPI #ifdef WITH_MPI
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
call MPI_Bcast(tmat1(1,i), l_row1-1, & call MPI_Bcast(tmat1(1,i), l_row1-1, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(n, nblk, np_cols), mpi_comm_cols, mpierr) pcol(n, nblk, np_cols), mpi_comm_cols, mpierr)
call obj%timer%stop("mpi_communication") call obj%timer%stop("mpi_communication")
...@@ -268,13 +235,7 @@ ...@@ -268,13 +235,7 @@
#ifdef WITH_MPI #ifdef WITH_MPI
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
if (l_cols-l_col1+1>0) & if (l_cols-l_col1+1>0) &
call MPI_Bcast(tmat2(1,l_col1), (l_cols-l_col1+1)*nblk, & call MPI_Bcast(tmat2(1,l_col1), (l_cols-l_col1+1)*nblk, MPI_MATH_DATATYPE_PRECISION, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
prow(n, nblk, np_rows), mpi_comm_rows, mpierr) prow(n, nblk, np_rows), mpi_comm_rows, mpierr)
call obj%timer%stop("mpi_communication") call obj%timer%stop("mpi_communication")
...@@ -282,20 +243,8 @@ ...@@ -282,20 +243,8 @@
call obj%timer%start("blas") call obj%timer%start("blas")
if (l_row1>1 .and. l_cols-l_col1+1>0) & if (l_row1>1 .and. l_cols-l_col1+1>0) &
call PRECISION_GEMM('N', 'N', l_row1-1, l_cols-l_col1+1, nb, & call PRECISION_GEMM('N', 'N', l_row1-1, l_cols-l_col1+1, nb, -ONE, &
#if REALCASE == 1 tmat1, ubound(tmat1,dim=1), tmat2(1,l_col1), ubound(tmat2,dim=1), ONE, &
-CONST_1_0, &
#endif
#if COMPLEXCASE == 1
-CONST_COMPLEX_PAIR_1_0, &
#endif
tmat1, ubound(tmat1,dim=1), tmat2(1,l_col1), ubound(tmat2,dim=1), &
#if REALCASE == 1
CONST_1_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_1_0, &
#endif
a(1,l_col1), lda) a(1,l_col1), lda)
call obj%timer%stop("blas") call obj%timer%stop("blas")
...@@ -315,7 +264,3 @@ ...@@ -315,7 +264,3 @@
&_& &_&
&PRECISION& &PRECISION&
&") &")
#undef REALCASE
#undef COMPLEXCASE
#undef DOUBLE_PRECISION
#undef SINGLE_PRECISION
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment