Commit ae45bbb3 authored by Andreas Marek
Browse files

Unify elpa_multiply_a_b

parent 94a4dc99
......@@ -66,6 +66,8 @@ EXTRA_libelpa@SUFFIX@_private_la_DEPENDENCIES = \
src/redist_band.X90 \
src/sanity.X90 \
src/elpa_cholesky_template.X90 \
src/elpa_invert_trm.X90 \
src/elpa_multiply_a_b.X90 \
src/precision_macros.h
lib_LTLIBRARIES = libelpa@SUFFIX@.la
......@@ -924,6 +926,8 @@ EXTRA_DIST = \
src/redist_band.X90 \
src/sanity.X90 \
src/elpa_cholesky_template.X90 \
src/elpa_invert_trm.X90 \
src/elpa_multiply_a_b.X90 \
src/elpa_qr/elpa_qrkernels.X90 \
src/ev_tridi_band_gpu_c_v2_complex_template.Xcu \
src/ev_tridi_band_gpu_c_v2_real_template.Xcu \
......
This diff is collapsed.
! Template body for the distributed multiplication routine whose timer label
! starts with elpa_mult_at_b_. This part holds the declarations, the process
! grid queries, the work buffer allocation and the decoding of the uplo flags.
! NOTE(review): the procedure statement itself is not part of this fragment;
! presumably the including file supplies it together with the case and
! precision macros - confirm against the includer.
#include "sanity.X90"
! Real timers are used only when detailed timings were configured; otherwise
! the dummy module is pulled in instead.
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
use elpa1_compute
use elpa_mpi
use precision
implicit none
! uplo_a / uplo_c: single characters; u or U selects the upper flag and
! l or L the lower flag for a and c respectively (decoded further below).
! Any other character leaves both flags of that matrix false.
character*1 :: uplo_a, uplo_c
! na is used for the local row count, nblk is the block size handed to
! local_index; lda/ldb/ldc and ldaCols/ldbCols/ldcCols dimension a, b and c.
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, nblk
! ncb is used for the local column count of b; the two communicators are
! queried for rank and size below.
integer(kind=ik) :: ncb, mpi_comm_rows, mpi_comm_cols
! Matrix arguments: real or complex depending on the case macro, assumed-size
! or explicit-shape depending on USE_ASSUMED_SIZE.
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else
real(kind=REAL_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else
complex(kind=COMPLEX_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif
#endif
! Position of this process in the row/column communicators and their sizes
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_rows_np
integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
integer(kind=ik) :: gcol_min, gcol, goff
integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
! Per-column first/last local row numbers remembered between the gather and
! the insert step of the main loop
integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
logical :: a_lower, a_upper, c_lower, c_upper
! Work buffers, typed like the matrix arguments
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
#endif
integer(kind=ik) :: istat
character(200) :: errorMessage
! NOTE(review): success is only ever set to .true. in this fragment;
! presumably it is the function result of the including procedure - confirm.
logical :: success
! The timer label is assembled by the preprocessor from the data type and
! precision macros, hence the line-by-line string continuation.
call timer%start("elpa_mult_at_b_&
&MATH_DATATYPE&
&_&
&PRECISION &
")
success = .true.
! Query the coordinates of this process and the extent of the process grid
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call timer%stop("mpi_communication")
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
! Block factor for matrix multiplications, must be a multiple of nblk
! (smallest multiple of nblk above 31, or above 63 when na/np_rows exceeds 256)
if (na/np_rows<=256) then
nblk_mult = (31/nblk+1)*nblk
else
nblk_mult = (63/nblk+1)*nblk
endif
! aux_mat collects up to nblk_mult broadcast columns before each multiply
allocate(aux_mat(l_rows,nblk_mult), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating aux_mat "//errorMessage
stop
endif
! aux_bc is the packed send/receive buffer for one block column (nblk columns
! of at most l_rows values each)
allocate(aux_bc(l_rows*nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating aux_bc "//errorMessage
stop
endif
allocate(lrs_save(nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating lrs_save "//errorMessage
stop
endif
allocate(lre_save(nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating lre_save "//errorMessage
stop
endif
! Decode the uplo characters (case-insensitive) into logical flags
a_lower = .false.
a_upper = .false.
c_lower = .false.
c_upper = .false.
if (uplo_a=='u' .or. uplo_a=='U') a_upper = .true.
if (uplo_a=='l' .or. uplo_a=='L') a_lower = .true.
if (uplo_c=='u' .or. uplo_c=='U') c_upper = .true.
if (uplo_c=='l' .or. uplo_c=='L') c_lower = .true.
! Build up the result matrix by processor rows
! Main loop: for every processor row np the rows of the result owned by that
! row are assembled. Block columns of a are gathered on their owning process
! column, broadcast along mpi_comm_cols, collected in aux_mat and multiplied
! (transposed for real data, conjugate-transposed for complex data) with the
! local part of b. The partial products are summed over mpi_comm_rows onto
! processor row np, which stores them in c.
! The broadcast datatype is selected for the real and for the complex build,
! so that both variants actually distribute the gathered block column.
do np = 0, np_rows-1

  ! In this turn, procs of row np assemble the result
  l_rows_np = local_index(na, np, np_rows, nblk, -1) ! local rows on receiving processors

  nr_done = 0 ! Number of rows done
  aux_mat = 0
  nstor = 0   ! Number of columns stored in aux_mat

  ! Loop over the blocks on row np
  do nb = 0, (l_rows_np-1)/nblk

    goff = nb*np_rows + np ! Global offset in blocks corresponding to nb

    ! Get the processor column which owns this block (A is transposed, so we need the column)
    ! and the offset in blocks within this column.
    ! The corresponding block column in A is then broadcast to all for multiplication with B
    np_bc = MOD(goff,np_cols)
    noff = goff/np_cols

    n_aux_bc = 0

    ! Gather up the complete block column of A on the owner.
    ! Every process computes the same row ranges and value counts, so that the
    ! broadcast length n_aux_bc agrees everywhere; only the owner fills aux_bc.
    do n = 1, min(l_rows_np-nb*nblk,nblk) ! Loop over columns to be broadcast

      gcol = goff*nblk + n ! global column corresponding to n
      ! Remember the first global column of the batch currently held in aux_mat
      if (nstor==0 .and. n==1) gcol_min = gcol

      lrs = 1      ! 1st local row number for broadcast
      lre = l_rows ! last local row number for broadcast
      if (a_lower) lrs = local_index(gcol, my_prow, np_rows, nblk, +1)
      if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)

      if (lrs<=lre) then
        nvals = lre-lrs+1
        if (my_pcol == np_bc) aux_bc(n_aux_bc+1:n_aux_bc+nvals) = a(lrs:lre,noff*nblk+n)
        n_aux_bc = n_aux_bc + nvals
      endif

      ! Keep the row range for the unpacking step after the broadcast
      lrs_save(n) = lrs
      lre_save(n) = lre

    enddo

    ! Broadcast block column
#ifdef WITH_MPI
    call timer%start("mpi_communication")
    call MPI_Bcast(aux_bc, n_aux_bc, &
#if REALCASE == 1
                   MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
                   MPI_COMPLEX_PRECISION, &
#endif
                   np_bc, mpi_comm_cols, mpierr)
    call timer%stop("mpi_communication")
#endif /* WITH_MPI */

    ! Insert what we got in aux_mat
    n_aux_bc = 0
    do n = 1, min(l_rows_np-nb*nblk,nblk)
      nstor = nstor+1
      lrs = lrs_save(n)
      lre = lre_save(n)
      if (lrs<=lre) then
        nvals = lre-lrs+1
        aux_mat(lrs:lre,nstor) = aux_bc(n_aux_bc+1:n_aux_bc+nvals)
        n_aux_bc = n_aux_bc + nvals
      endif
    enddo

    ! If we got nblk_mult columns in aux_mat or this is the last block
    ! do the matrix multiplication
    if (nstor==nblk_mult .or. nb*nblk+nblk >= l_rows_np) then

      lrs = 1      ! 1st local row number for multiply
      lre = l_rows ! last local row number for multiply
      if (a_lower) lrs = local_index(gcol_min, my_prow, np_rows, nblk, +1)
      if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)

      lcs = 1      ! 1st local col number for multiply
      lce = l_cols ! last local col number for multiply
      if (c_upper) lcs = local_index(gcol_min, my_pcol, np_cols, nblk, +1)
      if (c_lower) lce = MIN(local_index(gcol, my_pcol, np_cols, nblk, -1),l_cols)

      if (lcs<=lce) then
        ! tmp1 holds the local partial product, tmp2 receives the reduced sum
        allocate(tmp1(nstor,lcs:lce),tmp2(nstor,lcs:lce), stat=istat, errmsg=errorMessage)
        if (istat .ne. 0) then
          print *,"elpa_mult_at_b_&
          &MATH_DATATYPE&
          &: error when allocating tmp1 "//errorMessage
          stop
        endif

        if (lrs<=lre) then
          call timer%start("blas")
#if REALCASE == 1
          call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
          call PRECISION_GEMM('C', 'N', &
#endif
                              nstor, lce-lcs+1, lre-lrs+1, &
#if REALCASE == 1
                              CONST_1_0, &
#endif
#if COMPLEXCASE == 1
                              CONST_COMPLEX_PAIR_1_0, &
#endif
                              aux_mat(lrs,1), ubound(aux_mat,dim=1), &
                              b(lrs,lcs), ldb, &
#if REALCASE == 1
                              CONST_0_0, &
#endif
#if COMPLEXCASE == 1
                              CONST_COMPLEX_PAIR_0_0, &
#endif
                              tmp1, nstor)
          call timer%stop("blas")
        else
          ! No local rows take part: contribute zeros to the reduction
          tmp1 = 0
        endif

        ! Sum up the results and send to processor row np
#ifdef WITH_MPI
        call timer%start("mpi_communication")
        call mpi_reduce(tmp1, tmp2, nstor*(lce-lcs+1), &
#if REALCASE == 1
                        MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
                        MPI_COMPLEX_PRECISION, &
#endif
                        MPI_SUM, np, mpi_comm_rows, mpierr)
        call timer%stop("mpi_communication")
        ! Put the result into C
        if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp2(1:nstor,lcs:lce)
#else /* WITH_MPI */
        ! tmp2 = tmp1
        ! Put the result into C
        if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp1(1:nstor,lcs:lce)
#endif /* WITH_MPI */

        deallocate(tmp1,tmp2, stat=istat, errmsg=errorMessage)
        if (istat .ne. 0) then
          print *,"elpa_mult_at_b_&
          &MATH_DATATYPE&
          &: error when deallocating tmp1 "//errorMessage
          stop
        endif

      endif

      ! Start a fresh batch of columns in aux_mat
      nr_done = nr_done+nstor
      nstor=0
      aux_mat(:,:)=0
    endif
  enddo
enddo
! Release all work buffers allocated before the main loop; a failure is
! reported and terminates the program.
deallocate(aux_mat, aux_bc, lrs_save, lre_save, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when deallocating aux_mat "//errorMessage
stop
endif
! Stop the timer that was started at the beginning of the routine; the label
! is assembled by the preprocessor in the same way as in the start call.
call timer%stop("elpa_mult_at_b_&
&MATH_DATATYPE&
&_&
&PRECISION &
")
! Drop the case and precision selectors.
! NOTE(review): presumably this lets the including file define them again
! before including the template for another variant - confirm.
#undef REALCASE
#undef COMPLEXCASE
#undef DOUBLE_PRECISION
#undef SINGLE_PRECISION
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment