Commit 738b4615 authored by Andreas Marek's avatar Andreas Marek

Unify elpa_invert_trm

 This commit is triggered by issue #38
parent 36df86c7
This diff is collapsed.
#include "sanity.X90"
use precision
use elpa1_compute
use elpa_utilities
use elpa_mpi
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
implicit none
integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(lda,*)
#else
real(kind=REAL_DATATYPE) :: a(lda,matrixCols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(lda,*)
#else
complex(kind=COMPLEX_DATATYPE) :: a(lda,matrixCols)
#endif
#endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx
integer(kind=ik) :: n, nc, i, info, ns, nb
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmat1(:,:), tmat2(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmat1(:,:), tmat2(:,:)
#endif
logical, intent(in) :: wantDebug
logical :: success
integer(kind=ik) :: istat
character(200) :: errorMessage
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call timer%stop("mpi_communication")
success = .true.
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a
l_cols = local_index(na, my_pcol, np_cols, nblk, -1) ! Local cols of a
allocate(tmp1(nblk*nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_invert_trm_&
&MATH_DATATYPE&
&: error when allocating tmp1 "//errorMessage
stop
endif
allocate(tmp2(nblk,nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_invert_trm_&
&MATH_DATATYPE&
&: error when allocating tmp2 "//errorMessage
stop
endif
tmp1 = 0
tmp2 = 0
allocate(tmat1(l_rows,nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_invert_trm_&
&MATH_DATATYPE&
&: error when allocating tmat1 "//errorMessage
stop
endif
allocate(tmat2(nblk,l_cols), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_invert_trm_&
&MATH_DATATYPE&
&: error when allocating tmat2 "//errorMessage
stop
endif
tmat1 = 0
tmat2 = 0
ns = ((na-1)/nblk)*nblk + 1
do n = ns,1,-nblk
l_row1 = local_index(n, my_prow, np_rows, nblk, +1)
l_col1 = local_index(n, my_pcol, np_cols, nblk, +1)
nb = nblk
if (na-n+1 < nblk) nb = na-n+1
l_rowx = local_index(n+nb, my_prow, np_rows, nblk, +1)
l_colx = local_index(n+nb, my_pcol, np_cols, nblk, +1)
if (my_prow==prow(n, nblk, np_rows)) then
if (my_pcol==pcol(n, nblk, np_cols)) then
call timer%start("blas")
#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
call DTRTRI('U', 'N', nb, a(l_row1,l_col1), lda, info)
#else
call STRTRI('U', 'N', nb, a(l_row1,l_col1), lda, info)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
call ZTRTRI('U', 'N', nb, a(l_row1,l_col1), lda, info)
#else
call CTRTRI('U', 'N', nb, a(l_row1,l_col1), lda, info)
#endif
#endif
call timer%stop("blas")
if (info/=0) then
if (wantDebug) write(error_unit,*) "elpa_invert_trm_&
&MATH_DATATYPE&
&: Error in DTRTRI"
success = .false.
return
endif
nc = 0
do i=1,nb
tmp1(nc+1:nc+i) = a(l_row1:l_row1+i-1,l_col1+i-1)
nc = nc+i
enddo
endif
#ifdef WITH_MPI
call timer%start("mpi_communication")
call MPI_Bcast(tmp1, nb*(nb+1)/2, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(n, nblk, np_cols), mpi_comm_cols, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
nc = 0
do i=1,nb
tmp2(1:i,i) = tmp1(nc+1:nc+i)
nc = nc+i
enddo
call timer%start("blas")
if (l_cols-l_colx+1>0) &
#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
call DTRMM('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, 1.0_rk8, tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#else
call STRMM('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, 1.0_rk4, tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
call ZTRMM('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, (1.0_rk8,0.0_rk8), tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#else
call CTRMM('L', 'U', 'N', 'N', nb, l_cols-l_colx+1, (1.0_rk4,0.0_rk4), tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#endif
#endif
call timer%stop("blas")
if (l_colx<=l_cols) tmat2(1:nb,l_colx:l_cols) = a(l_row1:l_row1+nb-1,l_colx:l_cols)
if (my_pcol==pcol(n, nblk, np_cols)) tmat2(1:nb,l_col1:l_col1+nb-1) = tmp2(1:nb,1:nb) ! tmp2 has the lower left triangle 0
endif
if (l_row1>1) then
if (my_pcol==pcol(n, nblk, np_cols)) then
tmat1(1:l_row1-1,1:nb) = a(1:l_row1-1,l_col1:l_col1+nb-1)
a(1:l_row1-1,l_col1:l_col1+nb-1) = 0
endif
do i=1,nb
#ifdef WITH_MPI
call timer%start("mpi_communication")
call MPI_Bcast(tmat1(1,i), l_row1-1, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(n, nblk, np_cols), mpi_comm_cols, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
enddo
endif
#ifdef WITH_MPI
call timer%start("mpi_communication")
if (l_cols-l_col1+1>0) &
call MPI_Bcast(tmat2(1,l_col1), (l_cols-l_col1+1)*nblk, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
prow(n, nblk, np_rows), mpi_comm_rows, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
call timer%start("blas")
if (l_row1>1 .and. l_cols-l_col1+1>0) &
#if REALCASE == 1
call PRECISION_GEMM('N', 'N', l_row1-1, l_cols-l_col1+1, nb, -CONST_1_0, &
tmat1, ubound(tmat1,dim=1), tmat2(1,l_col1), ubound(tmat2,dim=1), &
CONST_1_0, a(1,l_col1), lda)
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'N', l_row1-1, l_cols-l_col1+1, nb, -CONST_COMPLEX_PAIR_1_0, &
tmat1, ubound(tmat1,dim=1), tmat2(1,l_col1), ubound(tmat2,dim=1), &
CONST_COMPLEX_PAIR_1_0, a(1,l_col1), lda)
#endif
call timer%stop("blas")
enddo
deallocate(tmp1, tmp2, tmat1, tmat2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_invert_trm_&
&MATH_DATATYPE&
&: error when deallocating tmp1 "//errorMessage
stop
endif
#undef REALCASE
#undef COMPLEXCASE
#undef DOUBLE_PRECISION
#undef SINGLE_PRECISION
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment