Commit 3c546600 authored by Andreas Marek's avatar Andreas Marek
Browse files

Allow blocking in band_to_full

It is now possible, for both the real and the complex case, to switch on
the blocked code path in the CPU version of band_to_full.
This closes issue #42.

Still to do:
- make blocking a run-time option
- allow tuning of the blocking parameters at run-time
parent b6283c67
......@@ -182,10 +182,16 @@
integer(kind=ik) :: i
#ifdef BAND_TO_FULL_BLOCKING
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: tmat_complete(:,:), t_tmp(:,:), t_tmp2(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmat_complete(:,:), t_tmp(:,:), t_tmp2(:,:)
#endif
integer(kind=ik) :: cwy_blocking, t_blocking, t_cols, t_rows
#endif
integer(kind=ik) :: istat
character(200) :: errorMessage
logical :: successCUDA
......@@ -217,6 +223,9 @@
! is not implemented in the GPU version
#endif
! the GPU version does not (yet) support blocking
! but the handling is the same for real/complex case
allocate(tmp1(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_&
......@@ -618,7 +627,8 @@
else ! do not useGPU
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
! t_blocking was formerly 2; 3 is a better choice
t_blocking = 3 ! number of matrices T (tmat) which are aggregated into a new (larger) T matrix (tmat_complete) and applied at once
......@@ -628,69 +638,90 @@
allocate(tmp1(max_local_cols*cwy_blocking), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating tmp1 "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating tmp1 "//errorMessage
stop
endif
allocate(tmp2(max_local_cols*cwy_blocking), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating tmp2 "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating tmp2 "//errorMessage
stop
endif
allocate(hvb(max_local_rows*cwy_blocking), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating hvb "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating hvb "//errorMessage
stop
endif
allocate(hvm(max_local_rows,cwy_blocking), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating hvm "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating hvm "//errorMessage
stop
endif
#endif
#if COMPLEXCASE == 1
#else /* BAND_TO_FULL_BLOCKING */
allocate(tmp1(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_complex: error when allocating tmp1 "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating tmp1 "//errorMessage
stop
endif
allocate(tmp2(max_local_cols*nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_complex: error when allocating tmp2 "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&: error when allocating tmp2 "//errorMessage
stop
endif
allocate(hvb(max_local_rows*nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_complex: error when allocating hvb "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating hvb "//errorMessage
stop
endif
allocate(hvm(max_local_rows,nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_complex: error when allocating hvm "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating hvm "//errorMessage
stop
endif
#endif
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
allocate(tmat_complete(cwy_blocking,cwy_blocking), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating tmat_complete "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating tmat_complete "//errorMessage
stop
endif
allocate(t_tmp(cwy_blocking,nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating t_tmp "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating t_tmp "//errorMessage
stop
endif
allocate(t_tmp2(cwy_blocking,nbw), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when allocating t_tmp2 "//errorMessage
print *,"trans_ev_band_to_full_&
&MATH_DATATYPE&
&: error when allocating t_tmp2 "//errorMessage
stop
endif
#endif
......@@ -713,14 +744,13 @@
! if ( na >= ((t_blocking+1)*nbw) ) then
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
do istep=1,((na-1)/nbw-1)/t_blocking + 1
#endif
#if COMPLEXCASE == 1
#else
do istep=1,(na-1)/nbw
#endif
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
! This is the call when using na >= ((t_blocking+1)*nbw)
! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
! As an alternative we add some special case handling if na < cwy_blocking
......@@ -732,20 +762,18 @@
ELSE
n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
END IF
#endif
#if COMPLEXCASE == 1
#else /* BAND_TO_FULL_BLOCKING */
n_cols = MIN(na,(istep+1)*nbw) - istep*nbw ! Number of columns in current step
#endif
#endif /* BAND_TO_FULL_BLOCKING */
! Broadcast all Householder vectors for current step compressed in hvb
nb = 0
ns = 0
do lc = 1, n_cols
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
ncol = (istep-1)*cwy_blocking + nbw + lc ! absolute column number of householder vector
#endif
#if COMPLEXCASE == 1
#else
ncol = istep*nbw + lc ! absolute column number of householder vector
#endif
nrow = ncol - nbw ! absolute number of pivot row
......@@ -780,10 +808,9 @@
nb = 0
do lc = 1, n_cols
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
nrow = (istep-1)*cwy_blocking + lc ! absolute number of pivot row
#endif
#if COMPLEXCASE == 1
#else
nrow = (istep-1)*nbw+lc ! absolute number of pivot row
#endif
l_rows = local_index(nrow-1, my_prow, np_rows, nblk, -1) ! row length for bcast
......@@ -798,7 +825,7 @@
nb = nb+l_rows
enddo
#if REALCASE == 1
#ifdef BAND_TO_FULL_BLOCKING
l_rows = local_index(MIN(na,(istep+1)*cwy_blocking), my_prow, np_rows, nblk, -1)
! compute tmat2 out of tmat(:,:,)
......@@ -810,34 +837,47 @@
tmat_complete(t_rows+1:t_rows+t_cols,t_rows+1:t_rows+t_cols) = tmat(1:t_cols,1:t_cols,(istep-1)*t_blocking + i)
call timer%start("blas")
if (i > 1) then
call PRECISION_GEMM('T', 'N', t_rows, t_cols, l_rows, CONST_1_0, hvm(1,1), max_local_rows, hvm(1,(i-1)*nbw+1), &
max_local_rows, CONST_0_0, t_tmp, cwy_blocking)
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
t_rows, t_cols, l_rows, ONE, hvm(1,1), max_local_rows, hvm(1,(i-1)*nbw+1), &
max_local_rows, ZERO, t_tmp, cwy_blocking)
call timer%stop("blas")
#ifdef WITH_MPI
call timer%start("mpi_communication")
call mpi_allreduce(t_tmp, t_tmp2, cwy_blocking*nbw, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
call mpi_allreduce(t_tmp, t_tmp2, cwy_blocking*nbw, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, mpi_comm_rows, mpierr)
call timer%stop("mpi_communication")
call timer%start("blas")
call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, CONST_1_0, tmat_complete, cwy_blocking, t_tmp2, cwy_blocking)
call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -CONST_1_0, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, ONE, tmat_complete, cwy_blocking, t_tmp2, cwy_blocking)
call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -ONE, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
t_tmp2, cwy_blocking)
call timer%stop("blas")
tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp2(1:t_rows,1:t_cols)
#else
#else /* WITH_MPI */
! t_tmp2(1:cwy_blocking,1:nbw) = t_tmp(1:cwy_blocking,1:nbw)
call timer%start("blas")
call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, CONST_1_0, tmat_complete, cwy_blocking, t_tmp, cwy_blocking)
call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -CONST_1_0, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, ONE, tmat_complete, cwy_blocking, t_tmp, cwy_blocking)
call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -ONE, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
t_tmp, cwy_blocking)
call timer%stop("blas")
tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp(1:t_rows,1:t_cols)
#endif
#endif /* WITH_MPI */
! call PRECISION_TRMM('L', 'U', 'N', 'N', t_rows, t_cols, CONST_1_0, tmat_complete, cwy_blocking, t_tmp2, cwy_blocking)
! call PRECISION_TRMM('R', 'U', 'N', 'N', t_rows, t_cols, -CONST_1_0, tmat_complete(t_rows+1,t_rows+1), cwy_blocking, &
......@@ -846,9 +886,7 @@
! tmat_complete(1:t_rows,t_rows+1:t_rows+t_cols) = t_tmp2(1:t_rows,1:t_cols)
endif
enddo
#endif /* REALCASE == 1 */
#if COMPLEXCASE == 1
#else /* BAND_TO_FULL_BLOCKING */
l_rows = local_index(MIN(na,(istep+1)*nbw), my_prow, np_rows, nblk, -1)
#endif
......@@ -894,15 +932,30 @@
call timer%start("blas")
if (l_rows>0) then
#ifdef BAND_TO_FULL_BLOCKING
#if REALCASE == 1
call PRECISION_TRMM('L', 'U', 'T', 'N', n_cols, l_cols, ONE, tmat_complete, cwy_blocking, tmp2, n_cols)
call PRECISION_TRMM('L', 'U', 'T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMM('L', 'U', 'C', 'N', &
#endif
n_cols, l_cols, ONE, tmat_complete, cwy_blocking, tmp2, n_cols)
call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), tmp2, n_cols, ONE, q, ldq)
#else /* BAND_TO_FULL_BLOCKING */
#if REALCASE == 1
call PRECISION_TRMM('L', 'U', 'T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_TRMM('L', 'U', 'C', 'N', n_cols, l_cols, ONE, tmat(1,1,istep), ubound(tmat,dim=1), tmp2, n_cols)
call PRECISION_TRMM('L', 'U', 'C', 'N', &
#endif
n_cols, l_cols, ONE, tmat(1,1,istep), ubound(tmat,dim=1), tmp2, n_cols)
call PRECISION_GEMM('N', 'N', l_rows, l_cols, n_cols, -ONE, hvm, ubound(hvm,dim=1), &
tmp2, n_cols, ONE, q, ldq)
#endif
#endif /* BAND_TO_FULL_BLOCKING */
endif
call timer%stop("blas")
......@@ -1013,13 +1066,13 @@
stop
endif
#if REALCASE == 1
if (useQr) then
deallocate(tmat_complete, t_tmp, t_tmp2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"trans_ev_band_to_full_real: error when deallocating tmat_complete, t_tmp, t_tmp2 "//errorMessage
stop
endif
#ifdef BAND_TO_FULL_BLOCKING
! Release the work arrays of the blocked code path (tmat_complete, t_tmp, t_tmp2).
! The guard is #ifdef so that it matches the guards used where these arrays are
! declared and allocated; the deallocation is then compiled in exactly when the
! allocation is.
deallocate(tmat_complete, t_tmp, t_tmp2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
  ! The message is split over continuation lines so that the preprocessor can
  ! substitute MATH_DATATYPE into the routine name. Inside a character context
  ! every continued line has to start with "&", otherwise the leading blanks of
  ! that line would become part of the string (and the code is non-conforming).
  print *,"trans_ev_band_to_full_&
  &MATH_DATATYPE&
  &: error when deallocating tmat_complete, t_tmp, t_tmp2 "//errorMessage
  stop
endif
#endif /* BAND_TO_FULL_BLOCKING */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment