Commit 0e5e0c20 authored by Andreas Marek
Browse files

Map multiply_at_b, multiply_ah_b to new interface

parent 70c71d7d
...@@ -54,13 +54,13 @@ ...@@ -54,13 +54,13 @@
#include "../sanity.X90" #include "../sanity.X90"
use elpa_type
#ifdef HAVE_DETAILED_TIMINGS #ifdef HAVE_DETAILED_TIMINGS
use timings use timings
#else #else
use timings_dummy use timings_dummy
#endif #endif
use elpa1_compute ! use elpa1_compute
use elpa_mpi use elpa_mpi
use precision use precision
implicit none implicit none
...@@ -70,36 +70,39 @@ ...@@ -70,36 +70,39 @@
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, nblk integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, nblk
integer(kind=ik) :: ncb, mpi_comm_rows, mpi_comm_cols integer(kind=ik) :: ncb, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1 #if REALCASE == 1
#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*) ! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else !#else
real(kind=REAL_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols) real(kind=REAL_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif !#endif
#endif #endif
#if COMPLEXCASE == 1 #if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*) complex(kind=COMPLEX_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else !#else
complex(kind=COMPLEX_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols) ! complex(kind=COMPLEX_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif !#endif
#endif #endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr ! integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_rows_np integer(kind=ik) :: nev
integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce ! integer(kind=ik) :: l_cols, l_rows, l_rows_np
integer(kind=ik) :: gcol_min, gcol, goff ! integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals ! integer(kind=ik) :: gcol_min, gcol, goff
integer(kind=ik), allocatable :: lrs_save(:), lre_save(:) ! integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
! integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
logical :: a_lower, a_upper, c_lower, c_upper
#if REALCASE == 1 ! logical :: a_lower, a_upper, c_lower, c_upper
real(kind=REAL_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:) !#if REALCASE == 1
#endif ! real(kind=REAL_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
#if COMPLEXCASE == 1 !#endif
complex(kind=COMPLEX_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:) !#if COMPLEXCASE == 1
#endif ! complex(kind=COMPLEX_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
integer(kind=ik) :: istat !#endif
character(200) :: errorMessage ! integer(kind=ik) :: istat
logical :: success ! character(200) :: errorMessage
logical :: success
integer(kind=ik) :: successInternal
type(elpa_t) :: elpaAPI
call timer%start("elpa_mult_at_b_& call timer%start("elpa_mult_at_b_&
&MATH_DATATYPE& &MATH_DATATYPE&
...@@ -109,239 +112,44 @@ ...@@ -109,239 +112,44 @@
success = .true. success = .true.
call timer%start("mpi_communication") !call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr) !call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr) !call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr) !call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call timer%stop("mpi_communication")
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
! Block factor for matrix multiplications, must be a multiple of nblk !l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
!l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
if (na/np_rows<=256) then if (elpa_init(20170403) /= ELPA_OK) then
nblk_mult = (31/nblk+1)*nblk success = .false.
else error stop "ELPA API version not supported"
nblk_mult = (63/nblk+1)*nblk
endif endif
allocate(aux_mat(l_rows,nblk_mult), stat=istat, errmsg=errorMessage) nev = 10
if (istat .ne. 0) then elpaAPI = elpa_create(na, nev, lda, ldaCols, nblk, successInternal)
print *,"elpa_mult_at_b_& if (successInternal .ne. ELPA_OK) then
&MATH_DATATYPE& print *, "Cannot create elpa object"
&: error when allocating aux_mat "//errorMessage success = .false.
stop 1 stop
return
endif endif
allocate(aux_bc(l_rows*nblk), stat=istat, errmsg=errorMessage) call elpaAPI%set_comm_rows(mpi_comm_rows)
if (istat .ne. 0) then call elpaAPI%set_comm_cols(mpi_comm_cols)
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating aux_bc "//errorMessage
stop 1
endif
allocate(lrs_save(nblk), stat=istat, errmsg=errorMessage) call elpaAPI%multiply_a_b(uplo_a, uplo_c, na, ncb, a(1:lda,1:ldaCols), lda, ldaCols, &
if (istat .ne. 0) then b(1:ldb,1:ldbCols), ldb, ldbCols, &
print *,"elpa_mult_at_b_& c(1:ldc,1:ldcCols), ldc, ldcCols, successInternal)
&MATH_DATATYPE&
&: error when allocating lrs_save "//errorMessage
stop 1
endif
allocate(lre_save(nblk), stat=istat, errmsg=errorMessage) if (successInternal .ne. ELPA_OK) then
if (istat .ne. 0) then print *, "Cannot run multiply_a_b"
print *,"elpa_mult_at_b_& stop
&MATH_DATATYPE& success = .false.
&: error when allocating lre_save "//errorMessage return
stop 1
endif endif
call elpaAPI%destroy()
a_lower = .false. call elpa_uninit()
a_upper = .false.
c_lower = .false.
c_upper = .false.
if (uplo_a=='u' .or. uplo_a=='U') a_upper = .true.
if (uplo_a=='l' .or. uplo_a=='L') a_lower = .true.
if (uplo_c=='u' .or. uplo_c=='U') c_upper = .true.
if (uplo_c=='l' .or. uplo_c=='L') c_lower = .true.
! Build up the result matrix by processor rows
do np = 0, np_rows-1
! In this turn, procs of row np assemble the result
l_rows_np = local_index(na, np, np_rows, nblk, -1) ! local rows on receiving processors
nr_done = 0 ! Number of rows done
aux_mat = 0
nstor = 0 ! Number of columns stored in aux_mat
! Loop over the blocks on row np
do nb=0,(l_rows_np-1)/nblk
goff = nb*np_rows + np ! Global offset in blocks corresponding to nb
! Get the processor column which owns this block (A is transposed, so we need the column)
! and the offset in blocks within this column.
! The corresponding block column in A is then broadcast to all for multiplication with B
np_bc = MOD(goff,np_cols)
noff = goff/np_cols
n_aux_bc = 0
! Gather up the complete block column of A on the owner
do n = 1, min(l_rows_np-nb*nblk,nblk) ! Loop over columns to be broadcast
gcol = goff*nblk + n ! global column corresponding to n
if (nstor==0 .and. n==1) gcol_min = gcol
lrs = 1 ! 1st local row number for broadcast
lre = l_rows ! last local row number for broadcast
if (a_lower) lrs = local_index(gcol, my_prow, np_rows, nblk, +1)
if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
if (lrs<=lre) then
nvals = lre-lrs+1
if (my_pcol == np_bc) aux_bc(n_aux_bc+1:n_aux_bc+nvals) = a(lrs:lre,noff*nblk+n)
n_aux_bc = n_aux_bc + nvals
endif
lrs_save(n) = lrs
lre_save(n) = lre
enddo
! Broadcast block column
#ifdef WITH_MPI
call timer%start("mpi_communication")
#if REALCASE == 1
call MPI_Bcast(aux_bc, n_aux_bc, &
MPI_REAL_PRECISION, &
np_bc, mpi_comm_cols, mpierr)
#endif
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
! Insert what we got in aux_mat
n_aux_bc = 0
do n = 1, min(l_rows_np-nb*nblk,nblk)
nstor = nstor+1
lrs = lrs_save(n)
lre = lre_save(n)
if (lrs<=lre) then
nvals = lre-lrs+1
aux_mat(lrs:lre,nstor) = aux_bc(n_aux_bc+1:n_aux_bc+nvals)
n_aux_bc = n_aux_bc + nvals
endif
enddo
! If we got nblk_mult columns in aux_mat or this is the last block
! do the matrix multiplication
if (nstor==nblk_mult .or. nb*nblk+nblk >= l_rows_np) then
lrs = 1 ! 1st local row number for multiply
lre = l_rows ! last local row number for multiply
if (a_lower) lrs = local_index(gcol_min, my_prow, np_rows, nblk, +1)
if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
lcs = 1 ! 1st local col number for multiply
lce = l_cols ! last local col number for multiply
if (c_upper) lcs = local_index(gcol_min, my_pcol, np_cols, nblk, +1)
if (c_lower) lce = MIN(local_index(gcol, my_pcol, np_cols, nblk, -1),l_cols)
if (lcs<=lce) then
allocate(tmp1(nstor,lcs:lce),tmp2(nstor,lcs:lce), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating tmp1 "//errorMessage
stop 1
endif
if (lrs<=lre) then
call timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
nstor, lce-lcs+1, lre-lrs+1, &
#if REALCASE == 1
CONST_1_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_1_0, &
#endif
aux_mat(lrs,1), ubound(aux_mat,dim=1), &
b(lrs,lcs), ldb, &
#if REALCASE == 1
CONST_0_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_0_0, &
#endif
tmp1, nstor)
call timer%stop("blas")
else
tmp1 = 0
endif
! Sum up the results and send to processor row np
#ifdef WITH_MPI
call timer%start("mpi_communication")
call mpi_reduce(tmp1, tmp2, nstor*(lce-lcs+1), &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, np, mpi_comm_rows, mpierr)
call timer%stop("mpi_communication")
! Put the result into C
if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp2(1:nstor,lcs:lce)
#else /* WITH_MPI */
! tmp2 = tmp1
! Put the result into C
if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp1(1:nstor,lcs:lce)
#endif /* WITH_MPI */
deallocate(tmp1,tmp2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when deallocating tmp1 "//errorMessage
stop 1
endif
endif
nr_done = nr_done+nstor
nstor=0
aux_mat(:,:)=0
endif
enddo
enddo
deallocate(aux_mat, aux_bc, lrs_save, lre_save, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when deallocating aux_mat "//errorMessage
stop 1
endif
call timer%stop("elpa_mult_at_b_& call timer%stop("elpa_mult_at_b_&
&MATH_DATATYPE& &MATH_DATATYPE&
......
...@@ -104,8 +104,8 @@ ...@@ -104,8 +104,8 @@
call timer%start("elpa_mult_at_b_& call timer%start("elpa_mult_at_b_&
&MATH_DATATYPE& &MATH_DATATYPE&
&_& &_&
&PRECISION & &PRECISION&
") &_new")
success = .true. success = .true.
...@@ -346,8 +346,8 @@ ...@@ -346,8 +346,8 @@
call timer%stop("elpa_mult_at_b_& call timer%stop("elpa_mult_at_b_&
&MATH_DATATYPE& &MATH_DATATYPE&
&_& &_&
&PRECISION & &PRECISION&
") &_new")
#undef REALCASE #undef REALCASE
#undef COMPLEXCASE #undef COMPLEXCASE
......
...@@ -70,6 +70,13 @@ module ELPA2_new ...@@ -70,6 +70,13 @@ module ELPA2_new
public :: elpa_solve_evp_real_2stage_double_new !< Driver routine for real double-precision 2-stage eigenvalue problem public :: elpa_solve_evp_real_2stage_double_new !< Driver routine for real double-precision 2-stage eigenvalue problem
public :: elpa_solve_evp_complex_2stage_double_new !< Driver routine for complex double-precision 2-stage eigenvalue problem public :: elpa_solve_evp_complex_2stage_double_new !< Driver routine for complex double-precision 2-stage eigenvalue problem
#ifdef WANT_SINGLE_PRECISION_REAL
public :: elpa_solve_evp_real_2stage_single_new !< Driver routine for real single-precision 2-stage eigenvalue problem
#endif
#ifdef WANT_SINGLE_PRECISION_COMPLEX
public :: elpa_solve_evp_complex_2stage_single_new !< Driver routine for complex single-precision 2-stage eigenvalue problem
#endif
contains contains
......
#include <elpa/elpa_constants.h> #include <elpa/elpa_constants.h>
#include "config-f90.h"
module elpa_type module elpa_type
use, intrinsic :: iso_c_binding use, intrinsic :: iso_c_binding
...@@ -47,6 +48,11 @@ module elpa_type ...@@ -47,6 +48,11 @@ module elpa_type
public :: elpa_init, elpa_initialized, elpa_uninit, elpa_create, elpa_t, c_int, c_double, c_float public :: elpa_init, elpa_initialized, elpa_uninit, elpa_create, elpa_t, c_int, c_double, c_float
interface elpa_create
module procedure elpa_create_generic
module procedure elpa_create_special
end interface
type :: elpa_t type :: elpa_t
private private
type(c_ptr) :: options = C_NULL_PTR type(c_ptr) :: options = C_NULL_PTR
...@@ -63,6 +69,10 @@ module elpa_type ...@@ -63,6 +69,10 @@ module elpa_type
generic, public :: get => elpa_get_integer generic, public :: get => elpa_get_integer
procedure, public :: get_communicators => get_communicators procedure, public :: get_communicators => get_communicators
procedure, public :: set_comm_rows
procedure, public :: set_comm_cols
generic, public :: solve => elpa_solve_real_double, & generic, public :: solve => elpa_solve_real_double, &
elpa_solve_real_single, & elpa_solve_real_single, &
elpa_solve_complex_double, & elpa_solve_complex_double, &
...@@ -162,7 +172,8 @@ module elpa_type ...@@ -162,7 +172,8 @@ module elpa_type
end subroutine end subroutine
function elpa_create(na, nev, local_nrows, local_ncols, nblk, mpi_comm_parent, process_row, process_col, success) result(obj) function elpa_create_generic(na, nev, local_nrows, local_ncols, nblk, mpi_comm_parent, &
process_row, process_col, success) result(obj)
use precision use precision
use elpa_mpi use elpa_mpi
use elpa_utilities, only : error_unit use elpa_utilities, only : error_unit
...@@ -208,6 +219,58 @@ module elpa_type ...@@ -208,6 +219,58 @@ module elpa_type
end function end function
!> Create an ELPA object without attaching any MPI communicators.
!>
!> Only the matrix geometry is stored here; the row and column communicators
!> are not set by this function (the assignments are intentionally left
!> commented out below) and have to be supplied afterwards by the caller.
!>
!> \param na            value stored in obj%na
!> \param nev           value stored in obj%nev
!> \param local_nrows   value stored in obj%local_nrows
!> \param local_ncols   value stored in obj%local_ncols
!> \param nblk          value stored in obj%nblk
!> \param success       ELPA_OK on success; ELPA_ERROR if elpa_init() has not
!>                      been called yet
!> \result obj          the new object; on error none of its components are
!>                      assigned by this function
function elpa_create_special(na, nev, local_nrows, local_ncols, nblk, success) result(obj)
  use precision
  use elpa_mpi
  use elpa_utilities, only : error_unit
  implicit none

  integer(kind=ik), intent(in) :: na, nev, local_nrows, local_ncols, nblk
  !integer, intent(in) :: mpi_comm_rows, mpi_comm_cols, process_row, process_col
  integer, intent(out)         :: success
  type(elpa_t)                 :: obj

  ! Refuse to build an object until the library has been initialised.
  if (.not.(elpa_initialized())) then
    write(error_unit, *) "elpa_create(): you must call elpa_init() once before creating instances of ELPA"
    success = ELPA_ERROR
    return
  endif

  obj%options     = elpa_allocate_options()
  obj%na          = na
  obj%nev         = nev
  obj%local_nrows = local_nrows
  obj%local_ncols = local_ncols
  obj%nblk        = nblk

  ! Communicators are deliberately not assigned in this variant.
  !obj%mpi_comm_rows = mpi_comm_rows
  !obj%mpi_comm_cols = mpi_comm_rows

  success = ELPA_OK
end function
!> Store the row communicator handle in the ELPA object.
!>
!> \param self           object whose mpi_comm_rows component is overwritten
!> \param mpi_comm_rows  communicator handle to store (copied as-is, no
!>                       validation is performed)
subroutine set_comm_rows(self, mpi_comm_rows)
  implicit none

  class(elpa_t), intent(inout) :: self
  integer, intent(in)          :: mpi_comm_rows

  self%mpi_comm_rows = mpi_comm_rows
end subroutine
!> Store the column communicator handle in the ELPA object.
!>
!> \param self           object whose mpi_comm_cols component is overwritten
!> \param mpi_comm_cols  communicator handle to store (copied as-is, no
!>                       validation is performed)
subroutine set_comm_cols(self, mpi_comm_cols)
  implicit none

  class(elpa_t), intent(inout) :: self
  integer, intent(in)          :: mpi_comm_cols

  self%mpi_comm_cols = mpi_comm_cols
end subroutine
subroutine elpa_set_integer(self, name, value, success) subroutine elpa_set_integer(self, name, value, success)
use iso_c_binding use iso_c_binding
...@@ -475,14 +538,15 @@ module elpa_type ...@@ -475,14 +538,15 @@ module elpa_type
use elpa_utilities, only : error_unit use elpa_utilities, only : error_unit
use iso_c_binding use iso_c_binding
use precision
implicit none implicit none
class(elpa_t) :: self class(elpa_t) :: self
!#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
! complex(kind=c_float_complex) :: a(self%local_nrows, *), q(self%local_nrows, *) ! complex(kind=c_float_complex) :: a(self%local_nrows, *), q(self%local_nrows, *)
!#else !#else
complex(kind=c_float_complex) :: a(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols) complex(kind=ck4) :: a(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
!#endif !#endif
real(kind=c_float) :: ev(self%na) real(kind=rk4) :: ev(self%na)
integer, optional :: success integer, optional :: success
integer(kind=c_int) :: success_internal integer(kind=c_int) :: success_internal
...@@ -546,14 +610,15 @@ module elpa_type ...@@ -546,14 +610,15 @@ module elpa_type
c, ldc, ldcCols, success) c, ldc, ldcCols, success)
use iso_c_binding use iso_c_binding
use elpa1_auxiliary_new use elpa1_auxiliary_new
use precision
implicit none implicit none
class(elpa_t) :: self class(elpa_t) :: self
character*1 :: uplo_a, uplo_c character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*) ! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else !#else
real(kind=c_double) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols) real(kind=rk8) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif !#endif
integer, optional :: success integer, optional :: success
logical :: success_l logical :: success_l
...@@ -575,14 +640,15 @@ module elpa_type ...@@ -575,14 +640,15 @@ module elpa_type
c, ldc, ldcCols, success) c, ldc, ldcCols, success)
use iso_c_binding use iso_c_binding
use elpa1_auxiliary_new use elpa1_auxiliary_new
use precision
implicit none implicit none
class(elpa_t) :: self class(elpa_t) :: self
character*1 :: uplo_a, uplo_c character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*) ! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else !#else
real(kind=c_float) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols) real(kind=rk4) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif !#endif
integer, optional :: success integer, optional :: success
logical :: success_l logical :: success_l
...@@ -605,14 +671,15 @@ module elpa_type ...@@ -605,14 +671,15 @@ module elpa_type
c, ldc, ldcCols, success) c, ldc, ldcCols, success)
use iso_c_binding use iso_c_binding
use elpa1_auxiliary_new use elpa1_auxiliary_new
use precision
implicit none implicit none
class(elpa_t) :: self class(elpa_t) :: self
character*1 :: uplo_a, uplo_c character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE !#ifdef USE_ASSUMED_SIZE
! complex(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)