Commit c2b2dfa9 authored by Andreas Marek's avatar Andreas Marek

Cleanup of ELPA 1/2_stage_solve_impl signatures

parent e5a3670c
......@@ -58,9 +58,8 @@ function elpa_solve_evp_&
&MATH_DATATYPE&
&_1stage_&
&PRECISION&
&_impl (obj, na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, &
mpi_comm_cols, mpi_comm_all, useGPU, time_evp_fwd, &
time_evp_solve, time_evp_back, summary_timings) result(success)
&_impl (obj, a, ev, q, time_evp_fwd, &
time_evp_solve, time_evp_back) result(success)
use precision
use cuda_functions
use mod_check_for_gpu
......@@ -77,27 +76,25 @@ function elpa_solve_evp_&
implicit none
class(elpa_t), intent(in) :: obj
integer(kind=c_int), intent(in) :: na, nev, lda, ldq, nblk, matrixCols, mpi_comm_rows, &
mpi_comm_cols, mpi_comm_all
real(kind=REAL_DATATYPE), intent(out) :: ev(na)
real(kind=REAL_DATATYPE), intent(out) :: ev(obj%na)
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,*)
real(kind=C_DATATYPE_KIND), intent(out) :: q(ldq,*)
real(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,*)
real(kind=C_DATATYPE_KIND), intent(out) :: q(obj%local_nrows,*)
#else
real(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,matrixCols)
real(kind=C_DATATYPE_KIND), intent(out) :: q(ldq,matrixCols)
real(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,matrixCols)
real(kind=C_DATATYPE_KIND), intent(out) :: q(obj%local_nrows,matrixCols)
#endif
real(kind=C_DATATYPE_KIND), allocatable :: tau(:)
#endif /* REALCASE */
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
complex(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,*)
complex(kind=C_DATATYPE_KIND), intent(out) :: q(ldq,*)
complex(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,*)
complex(kind=C_DATATYPE_KIND), intent(out) :: q(obj%local_nrows,*)
#else
complex(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,matrixCols)
complex(kind=C_DATATYPE_KIND), intent(out) :: q(ldq,matrixCols)
complex(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,matrixCols)
complex(kind=C_DATATYPE_KIND), intent(out) :: q(obj%local_nrows,matrixCols)
#endif
real(kind=REAL_DATATYPE), allocatable :: q_real(:,:)
......@@ -105,12 +102,12 @@ function elpa_solve_evp_&
integer(kind=c_int) :: l_cols, l_rows, l_cols_nev, np_rows, np_cols
#endif /* COMPLEXCASE */
logical, intent(in) :: useGPU
logical :: useGPU
logical :: success
real(kind=c_double) :: time_evp_fwd, &
time_evp_solve, time_evp_back
logical, intent(in) :: summary_timings
logical :: summary_timings
logical :: do_useGPU
integer(kind=ik) :: numberOfGPUDevices
......@@ -120,6 +117,8 @@ function elpa_solve_evp_&
logical :: wantDebug
integer(kind=c_int) :: istat
character(200) :: errorMessage
integer(kind=ik) :: na, nev, lda, ldq, nblk, matrixCols, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all
call timer%start("elpa_solve_evp_&
&MATH_DATATYPE&
......@@ -127,6 +126,28 @@ function elpa_solve_evp_&
&PRECISION&
&")
na = obj%na
nev = obj%nev
lda = obj%local_nrows
ldq = obj%local_nrows
nblk = obj%nblk
matrixCols = obj%local_ncols
mpi_comm_rows = obj%get("mpi_comm_rows")
mpi_comm_cols = obj%get("mpi_comm_cols")
mpi_comm_all = obj%get("mpi_comm_parent")
if (obj%get("gpu") .eq. 1) then
useGPU =.true.
else
useGPU = .false.
endif
if (obj%get("summary_timings") .eq. 1) then
summary_timings = .true.
else
summary_timings = .false.
endif
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_all,my_pe,mpierr)
......@@ -182,7 +203,6 @@ function elpa_solve_evp_&
endif
endif
endif
#if COMPLEXCASE == 1
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and q
l_cols = local_index(na, my_pcol, np_cols, nblk, -1) ! Local columns of q
......@@ -199,7 +219,6 @@ function elpa_solve_evp_&
stop 1
endif
#endif
allocate(e(na), tau(na), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"solve_evp_&
......@@ -209,20 +228,17 @@ function elpa_solve_evp_&
&" // ": error when allocating e, tau "//errorMessage
stop 1
endif
ttt0 = MPI_Wtime()
call tridiag_&
&MATH_DATATYPE&
&_&
&PRECISION&
& (na, a, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, ev, e, tau, do_useGPU)
ttt1 = MPI_Wtime()
if(my_prow==0 .and. my_pcol==0 .and. summary_timings) write(error_unit,*) 'Time tridiag_real :',ttt1-ttt0
time_evp_fwd = ttt1-ttt0
ttt0 = MPI_Wtime()
call solve_tridi_&
&PRECISION&
& (na, nev, ev, e, &
......@@ -243,7 +259,6 @@ function elpa_solve_evp_&
#if COMPLEXCASE == 1
q(1:l_rows,1:l_cols_nev) = q_real(1:l_rows,1:l_cols_nev)
#endif
call trans_ev_&
&MATH_DATATYPE&
&_&
......@@ -252,7 +267,6 @@ function elpa_solve_evp_&
ttt1 = MPI_Wtime()
if(my_prow==0 .and. my_pcol==0 .and. summary_timings) write(error_unit,*) 'Time trans_ev_real:',ttt1-ttt0
time_evp_back = ttt1-ttt0
#if COMPLEXCASE == 1
deallocate(q_real, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
......@@ -264,7 +278,6 @@ function elpa_solve_evp_&
stop 1
endif
#endif
deallocate(e, tau, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"solve_evp_&
......@@ -280,7 +293,6 @@ function elpa_solve_evp_&
&_1stage_&
&PRECISION&
&")
end function
......@@ -54,12 +54,8 @@
&_&
&2stage_&
&PRECISION&
&_impl (obj, na, nev, a, lda, ev, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
time_evp_fwd, time_evp_solve, time_evp_back, summary_timings, useGPU, kernel &
#if REALCASE == 1
, useQR &
#endif
&) result(success)
&_impl (obj, a, ev, q, &
time_evp_fwd, time_evp_solve, time_evp_back) result(success)
#ifdef HAVE_DETAILED_TIMINGS
use timings
......@@ -77,9 +73,9 @@
use iso_c_binding
implicit none
class(elpa_t), intent(in) :: obj
logical, intent(in) :: useGPU
logical :: useGPU
#if REALCASE == 1
logical, intent(in), optional :: useQR
logical :: useQR
#endif
logical :: useQRActual
......@@ -87,16 +83,13 @@
integer(kind=c_int) :: kernel
integer(kind=c_int), intent(in) :: na, nev, lda, ldq, matrixCols, mpi_comm_rows, &
mpi_comm_cols, mpi_comm_all
integer(kind=c_int), intent(in) :: nblk
#ifdef USE_ASSUMED_SIZE
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,*), q(ldq,*)
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,*), q(obj%local_nrows,*)
#else
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(inout) :: a(lda,matrixCols), q(ldq,matrixCols)
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(inout) :: a(obj%local_nrows,obj%local_ncols), &
q(obj%local_nrows,obj%local_ncols)
#endif
real(kind=C_DATATYPE_KIND), intent(inout) :: ev(na)
real(kind=C_DATATYPE_KIND), intent(inout) :: ev(obj%na)
MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: hh_trans(:,:)
integer(kind=c_int) :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr
......@@ -107,7 +100,7 @@
real(kind=C_DATATYPE_KIND), allocatable :: q_real(:,:)
#endif
real(kind=c_double) :: time_evp_fwd, time_evp_solve, time_evp_back
logical, intent(in) :: summary_timings
logical :: summary_timings
integer(kind=c_intptr_t) :: tmat_dev, q_dev, a_dev
real(kind=c_double) :: ttt0, ttt1, ttts ! MPI_WTIME always needs double
......@@ -122,6 +115,8 @@
&PRECISION&
&_&
&MATH_DATATYPE
integer(kind=ik) :: na, nev, lda, ldq, nblk, matrixCols, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all
call timer%start("solve_evp_&
&MATH_DATATYPE&
......@@ -129,6 +124,59 @@
&PRECISION_SUFFIX &
)
na = obj%na
nev = obj%nev
lda = obj%local_nrows
ldq = obj%local_nrows
nblk = obj%nblk
matrixCols = obj%local_ncols
#if REALCASE == 1
kernel = obj%get("real_kernel")
! check consistency between request for GPUs and defined kernel
if (obj%get("gpu") == 1) then
if (kernel .ne. ELPA_2STAGE_REAL_GPU) then
write(error_unit,*) "ELPA: Warning, GPU usage has been requested but compute kernel is defined as non-GPU!"
else if (obj%get("nblk") .ne. 128) then
kernel = ELPA_2STAGE_REAL_GENERIC
endif
endif
#endif
#if COMPLEXCASE == 1
kernel = obj%get("complex_kernel")
! check consistency between request for GPUs and defined kernel
if (obj%get("gpu") == 1) then
if (kernel .ne. ELPA_2STAGE_COMPLEX_GPU) then
write(error_unit,*) "ELPA: Warning, GPU usage has been requested but compute kernel is defined as non-GPU!"
else if (obj%get("nblk") .ne. 128) then
kernel = ELPA_2STAGE_COMPLEX_GENERIC
endif
endif
#endif
mpi_comm_rows = obj%get("mpi_comm_rows")
mpi_comm_cols = obj%get("mpi_comm_cols")
mpi_comm_all = obj%get("mpi_comm_parent")
if (obj%get("summary_timings") .eq. 1) then
summary_timings = .true.
else
summary_timings = .false.
endif
if (obj%get("gpu") .eq. 1) then
useGPU = .true.
else
useGPU = .false.
endif
#if REALCASE == 1
if (obj%get("qr") .eq. 1) then
useQR = .true.
else
useQR = .false.
endif
#endif
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_all,my_pe,mpierr)
call mpi_comm_size(mpi_comm_all,n_pes,mpierr)
......@@ -149,10 +197,8 @@
#if REALCASE == 1
useQRActual = .false.
! set usage of qr decomposition via API call
if (present(useQR)) then
if (useQR) useQRActual = .true.
if (.not.(useQR)) useQRACtual = .false.
endif
if (useQR) useQRActual = .true.
if (.not.(useQR)) useQRACtual = .false.
if (useQRActual) then
if (mod(na,2) .ne. 0) then
......
......@@ -86,9 +86,6 @@ module elpa_api
procedure(elpa_is_set_i), deferred, public :: is_set
procedure(elpa_can_set_i), deferred, public :: can_set
procedure(elpa_get_int_i), deferred, private :: get_real_kernel
procedure(elpa_get_int_i), deferred, private :: get_complex_kernel
! actual math routines
generic, public :: solve => &
elpa_solve_real_double, &
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment