Commit 36df86c7 authored by Andreas Marek's avatar Andreas Marek

Unify elpa_cholesky

This commit is triggered by issue #38
parent 348644f2
......@@ -64,6 +64,8 @@ EXTRA_libelpa@SUFFIX@_private_la_DEPENDENCIES = \
src/elpa2_kernels/elpa2_kernels_complex_template.X90 \
src/elpa2_kernels/elpa2_kernels_simple_template.X90 \
src/redist_band.X90 \
src/sanity.X90 \
src/elpa_cholesky_template.X90 \
src/precision_macros.h
lib_LTLIBRARIES = libelpa@SUFFIX@.la
......@@ -920,6 +922,8 @@ EXTRA_DIST = \
src/elpa2_kernels/elpa2_kernels_complex_template.X90 \
src/elpa2_kernels/elpa2_kernels_simple_template.X90 \
src/redist_band.X90 \
src/sanity.X90 \
src/elpa_cholesky_template.X90 \
src/elpa_qr/elpa_qrkernels.X90 \
src/ev_tridi_band_gpu_c_v2_complex_template.Xcu \
src/ev_tridi_band_gpu_c_v2_real_template.Xcu \
......
This diff is collapsed.
#include "sanity.X90"
use elpa1_compute
use elpa_utilities
use elpa_mpi
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
use precision
implicit none
integer(kind=ik) :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(lda,*)
#else
real(kind=REAL_DATATYPE) :: a(lda,matrixCols)
#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(lda,*)
#else
complex(kind=COMPLEX_DATATYPE) :: a(lda,matrixCols)
#endif
#endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_col1, l_row1, l_colx, l_rowx
integer(kind=ik) :: n, nc, i, info
integer(kind=ik) :: lcs, lce, lrs, lre
integer(kind=ik) :: tile_size, l_rows_tile, l_cols_tile
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmatr(:,:), tmatc(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:,:), tmatr(:,:), tmatc(:,:)
#endif
logical, intent(in) :: wantDebug
logical :: success
integer(kind=ik) :: istat
character(200) :: errorMessage
call timer%start("elpa_cholesky_&
&MATH_DATATYPE&
&_&
&PRECISION &
")
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call timer%stop("mpi_communication")
success = .true.
! Matrix is split into tiles; work is done only for tiles on the diagonal or above
tile_size = nblk*least_common_multiple(np_rows,np_cols) ! minimum global tile size
tile_size = ((128*max(np_rows,np_cols)-1)/tile_size+1)*tile_size ! make local tiles at least 128 wide
l_rows_tile = tile_size/np_rows ! local rows of a tile
l_cols_tile = tile_size/np_cols ! local cols of a tile
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a
l_cols = local_index(na, my_pcol, np_cols, nblk, -1) ! Local cols of a
allocate(tmp1(nblk*nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_cholesky_&
&MATH_DATATYPE&: error when allocating tmp1 "//errorMessage
stop
endif
allocate(tmp2(nblk,nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_cholesky_&
&MATH_DATATYPE&
&: error when allocating tmp2 "//errorMessage
stop
endif
tmp1 = 0
tmp2 = 0
allocate(tmatr(l_rows,nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_cholesky_&
&MATH_DATATYPE&
&: error when allocating tmatr "//errorMessage
stop
endif
allocate(tmatc(l_cols,nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_cholesky_&
&MATH_DATATYPE&
&: error when allocating tmatc "//errorMessage
stop
endif
tmatr = 0
tmatc = 0
do n = 1, na, nblk
! Calculate first local row and column of the still remaining matrix
! on the local processor
l_row1 = local_index(n, my_prow, np_rows, nblk, +1)
l_col1 = local_index(n, my_pcol, np_cols, nblk, +1)
l_rowx = local_index(n+nblk, my_prow, np_rows, nblk, +1)
l_colx = local_index(n+nblk, my_pcol, np_cols, nblk, +1)
if (n+nblk > na) then
! This is the last step, just do a Cholesky-Factorization
! of the remaining block
if (my_prow==prow(n, nblk, np_rows) .and. my_pcol==pcol(n, nblk, np_cols)) then
call timer%start("blas")
call PRECISION_POTRF('U', na-n+1, a(l_row1,l_col1), lda, info)
call timer%stop("blas")
if (info/=0) then
if (wantDebug) write(error_unit,*) "elpa_cholesky_&
&MATH_DATATYPE&
#if REALCASE == 1
&: Error in dpotrf: ",info
#endif
#if COMPLEXCASE == 1
&: Error in zpotrf: ",info
#endif
success = .false.
return
endif
endif
exit ! Loop
endif
if (my_prow==prow(n, nblk, np_rows)) then
if (my_pcol==pcol(n, nblk, np_cols)) then
! The process owning the upper left remaining block does the
! Cholesky-Factorization of this block
call timer%start("blas")
call PRECISION_POTRF('U', nblk, a(l_row1,l_col1), lda, info)
call timer%stop("blas")
if (info/=0) then
if (wantDebug) write(error_unit,*) "elpa_cholesky_&
&MATH_DATATYPE&
#if REALCASE == 1
&: Error in dpotrf 2: ",info
#endif
#if COMPLEXCASE == 1
&: Error in zpotrf 2: ",info
#endif
success = .false.
return
endif
nc = 0
do i=1,nblk
tmp1(nc+1:nc+i) = a(l_row1:l_row1+i-1,l_col1+i-1)
nc = nc+i
enddo
endif
#ifdef WITH_MPI
call timer%start("mpi_communication")
call MPI_Bcast(tmp1, nblk*(nblk+1)/2, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
pcol(n, nblk, np_cols), mpi_comm_cols, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
nc = 0
do i=1,nblk
tmp2(1:i,i) = tmp1(nc+1:nc+i)
nc = nc+i
enddo
call timer%start("blas")
if (l_cols-l_colx+1>0) &
#if REALCASE == 1
call PRECISION_TRSM('L', 'U', 'T', 'N', nblk, l_cols-l_colx+1, CONST_1_0, tmp2, ubound(tmp2,dim=1), &
a(l_row1,l_colx), lda)
#endif
#if COMPLEXCASE == 1
call PRECISION_TRSM('L', 'U', 'C', 'N', nblk, l_cols-l_colx+1, CONST_COMPLEX_PAIR_1_0, &
tmp2, ubound(tmp2,dim=1), a(l_row1,l_colx), lda)
#endif
call timer%stop("blas")
endif
do i=1,nblk
#if REALCASE == 1
if (my_prow==prow(n, nblk, np_rows)) tmatc(l_colx:l_cols,i) = a(l_row1+i-1,l_colx:l_cols)
#endif
#if COMPLEXCASE == 1
if (my_prow==prow(n, nblk, np_rows)) tmatc(l_colx:l_cols,i) = conjg(a(l_row1+i-1,l_colx:l_cols))
#endif
#ifdef WITH_MPI
call timer%start("mpi_communication")
if (l_cols-l_colx+1>0) &
call MPI_Bcast(tmatc(l_colx,i), l_cols-l_colx+1, &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
prow(n, nblk, np_rows), mpi_comm_rows, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
enddo
! this has to be checked since it was changed substantially when doing type safe
call elpa_transpose_vectors_&
&MATH_DATATYPE&
&_&
&PRECISION &
(tmatc, ubound(tmatc,dim=1), mpi_comm_cols, &
tmatr, ubound(tmatr,dim=1), mpi_comm_rows, &
n, na, nblk, nblk)
do i=0,(na-1)/tile_size
lcs = max(l_colx,i*l_cols_tile+1)
lce = min(l_cols,(i+1)*l_cols_tile)
lrs = l_rowx
lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<lrs) cycle
call timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('N', 'T', lre-lrs+1, lce-lcs+1, nblk, -CONST_1_0, &
tmatr(lrs,1), ubound(tmatr,dim=1), tmatc(lcs,1), ubound(tmatc,dim=1), &
CONST_1_0, a(lrs,lcs), lda)
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('N', 'C', lre-lrs+1, lce-lcs+1, nblk, -CONST_COMPLEX_PAIR_1_0, &
tmatr(lrs,1), ubound(tmatr,dim=1), tmatc(lcs,1), ubound(tmatc,dim=1), &
CONST_COMPLEX_PAIR_1_0, a(lrs,lcs), lda)
#endif
call timer%stop("blas")
enddo
enddo
deallocate(tmp1, tmp2, tmatr, tmatc, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_cholesky_&
&MATH_DATATYPE&
&: error when deallocating tmp1 "//errorMessage
stop
endif
! Set the lower triangle to 0, it contains garbage (form the above matrix multiplications)
do i=1,na
if (my_pcol==pcol(i, nblk, np_cols)) then
! column i is on local processor
l_col1 = local_index(i , my_pcol, np_cols, nblk, +1) ! local column number
l_row1 = local_index(i+1, my_prow, np_rows, nblk, +1) ! first row below diagonal
a(l_row1:l_rows,l_col1) = 0
endif
enddo
call timer%stop("elpa_cholesky_&
&MATH_DATATYPE&
&_&
&PRECISION&
")
#undef REALCASE
#undef COMPLEXCASE
#undef DOUBLE_PRECISION
#undef SINGLE_PRECISION
......@@ -6,6 +6,8 @@
#undef PRECISION_STR
#undef REAL_DATATYPE
#undef PRECISION_POTRF
#undef PRECISION_TRSM
#undef PRECISION_GEMV
#undef PRECISION_TRMV
#undef PRECISION_GEMM
......@@ -50,6 +52,8 @@
#define PRECISION_SUFFIX "_double"
#define REAL_DATATYPE rk8
#define PRECISION_POTRF DPOTRF
#define PRECISION_TRSM DTRSM
#define PRECISION_GEMV DGEMV
#define PRECISION_TRMV DTRMV
#define PRECISION_GEMM DGEMM
......@@ -91,6 +95,8 @@
#define PRECISION_SUFFIX "_single"
#define REAL_DATATYPE rk4
#define PRECISION_POTRF SPOTRF
#define PRECISION_TRSM STRSM
#define PRECISION_GEMV SGEMV
#define PRECISION_TRMV STRMV
#define PRECISION_GEMM SGEMM
......@@ -136,6 +142,8 @@
/* in the complex case also sometime real valued variables are needed */
#undef REAL_DATATYPE
#undef PRECISION_POTRF
#undef PRECISION_TRSM
#undef PRECISION_STR
#undef PRECISION_GEMV
#undef PRECISION_TRMV
......@@ -192,6 +200,8 @@
#define COMPLEX_DATATYPE CK8
#define REAL_DATATYPE RK8
#define PRECISION_POTRF ZPOTRF
#define PRECISION_TRSM ZTRSM
#define PRECISION_GEMV ZGEMV
#define PRECISION_TRMV ZTRMV
#define PRECISION_GEMM ZGEMM
......@@ -243,6 +253,8 @@
#define COMPLEX_DATATYPE CK4
#define REAL_DATATYPE RK4
#define PRECISION_POTRF CPOTRF
#define PRECISION_TRSM CTRSM
#define PRECISION_GEMV CGEMV
#define PRECISION_TRMV CTRMV
#define PRECISION_GEMM CGEMM
......
#ifdef REALCASE
#ifdef COMPLEXCASE
#error Cannot define both REALCASE and COMPLEXCASE
#endif
#endif
#ifndef REALCASE
#ifndef COMPLEXCASE
#error Must define one of REALCASE or COMPLEXCASE
#endif
#endif
#ifdef SINGLE_PRECISION
#ifdef DOUBLE_PRECISION
#error Cannot define both SINGLE_PRECISION and DOUBLE_PRECISION
#endif
#endif
#ifndef SINGLE_PRECISION
#ifndef DOUBLE_PRECISION
#error Must define one of SINGLE_PRECISION or DOUBLE_PRECISION
#endif
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment