Commit 0e5e0c20 authored by Andreas Marek's avatar Andreas Marek
Browse files

Map multiply_at_b, multiply_ah_b to new interface

parent 70c71d7d
......@@ -54,13 +54,13 @@
#include "../sanity.X90"
use elpa_type
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
use elpa1_compute
! use elpa1_compute
use elpa_mpi
use precision
implicit none
......@@ -70,36 +70,39 @@
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, nblk
integer(kind=ik) :: ncb, mpi_comm_rows, mpi_comm_cols
#if REALCASE == 1
#ifdef USE_ASSUMED_SIZE
real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else
!#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else
real(kind=REAL_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif
!#endif
#endif
#if COMPLEXCASE == 1
#ifdef USE_ASSUMED_SIZE
!#ifdef USE_ASSUMED_SIZE
complex(kind=COMPLEX_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
#else
complex(kind=COMPLEX_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
#endif
#endif
integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: l_cols, l_rows, l_rows_np
integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
integer(kind=ik) :: gcol_min, gcol, goff
integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
logical :: a_lower, a_upper, c_lower, c_upper
#if REALCASE == 1
real(kind=REAL_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
#endif
#if COMPLEXCASE == 1
complex(kind=COMPLEX_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
#endif
integer(kind=ik) :: istat
character(200) :: errorMessage
logical :: success
!#else
! complex(kind=COMPLEX_DATATYPE) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif
#endif
! integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: nev
! integer(kind=ik) :: l_cols, l_rows, l_rows_np
! integer(kind=ik) :: np, n, nb, nblk_mult, lrs, lre, lcs, lce
! integer(kind=ik) :: gcol_min, gcol, goff
! integer(kind=ik) :: nstor, nr_done, noff, np_bc, n_aux_bc, nvals
! integer(kind=ik), allocatable :: lrs_save(:), lre_save(:)
! logical :: a_lower, a_upper, c_lower, c_upper
!#if REALCASE == 1
! real(kind=REAL_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
!#endif
!#if COMPLEXCASE == 1
! complex(kind=COMPLEX_DATATYPE), allocatable :: aux_mat(:,:), aux_bc(:), tmp1(:,:), tmp2(:,:)
!#endif
! integer(kind=ik) :: istat
! character(200) :: errorMessage
logical :: success
integer(kind=ik) :: successInternal
type(elpa_t) :: elpaAPI
call timer%start("elpa_mult_at_b_&
&MATH_DATATYPE&
......@@ -109,239 +112,44 @@
success = .true.
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call timer%stop("mpi_communication")
l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
!call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
!call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
!call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
!call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
! Block factor for matrix multiplications, must be a multiple of nblk
!l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a and b
!l_cols = local_index(ncb, my_pcol, np_cols, nblk, -1) ! Local cols of b
if (na/np_rows<=256) then
nblk_mult = (31/nblk+1)*nblk
else
nblk_mult = (63/nblk+1)*nblk
if (elpa_init(20170403) /= ELPA_OK) then
success = .false.
error stop "ELPA API version not supported"
endif
allocate(aux_mat(l_rows,nblk_mult), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating aux_mat "//errorMessage
stop 1
nev = 10
elpaAPI = elpa_create(na, nev, lda, ldaCols, nblk, successInternal)
if (successInternal .ne. ELPA_OK) then
print *, "Cannot create elpa object"
success = .false.
stop
return
endif
allocate(aux_bc(l_rows*nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating aux_bc "//errorMessage
stop 1
endif
call elpaAPI%set_comm_rows(mpi_comm_rows)
call elpaAPI%set_comm_cols(mpi_comm_cols)
allocate(lrs_save(nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating lrs_save "//errorMessage
stop 1
endif
call elpaAPI%multiply_a_b(uplo_a, uplo_c, na, ncb, a(1:lda,1:ldaCols), lda, ldaCols, &
b(1:ldb,1:ldbCols), ldb, ldbCols, &
c(1:ldc,1:ldcCols), ldc, ldcCols, successInternal)
allocate(lre_save(nblk), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating lre_save "//errorMessage
stop 1
if (successInternal .ne. ELPA_OK) then
print *, "Cannot run multiply_a_b"
stop
success = .false.
return
endif
call elpaAPI%destroy()
a_lower = .false.
a_upper = .false.
c_lower = .false.
c_upper = .false.
if (uplo_a=='u' .or. uplo_a=='U') a_upper = .true.
if (uplo_a=='l' .or. uplo_a=='L') a_lower = .true.
if (uplo_c=='u' .or. uplo_c=='U') c_upper = .true.
if (uplo_c=='l' .or. uplo_c=='L') c_lower = .true.
! Build up the result matrix by processor rows
do np = 0, np_rows-1
! In this turn, procs of row np assemble the result
l_rows_np = local_index(na, np, np_rows, nblk, -1) ! local rows on receiving processors
nr_done = 0 ! Number of rows done
aux_mat = 0
nstor = 0 ! Number of columns stored in aux_mat
! Loop over the blocks on row np
do nb=0,(l_rows_np-1)/nblk
goff = nb*np_rows + np ! Global offset in blocks corresponding to nb
! Get the processor column which owns this block (A is transposed, so we need the column)
! and the offset in blocks within this column.
! The corresponding block column in A is then broadcast to all for multiplication with B
np_bc = MOD(goff,np_cols)
noff = goff/np_cols
n_aux_bc = 0
! Gather up the complete block column of A on the owner
do n = 1, min(l_rows_np-nb*nblk,nblk) ! Loop over columns to be broadcast
gcol = goff*nblk + n ! global column corresponding to n
if (nstor==0 .and. n==1) gcol_min = gcol
lrs = 1 ! 1st local row number for broadcast
lre = l_rows ! last local row number for broadcast
if (a_lower) lrs = local_index(gcol, my_prow, np_rows, nblk, +1)
if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
if (lrs<=lre) then
nvals = lre-lrs+1
if (my_pcol == np_bc) aux_bc(n_aux_bc+1:n_aux_bc+nvals) = a(lrs:lre,noff*nblk+n)
n_aux_bc = n_aux_bc + nvals
endif
lrs_save(n) = lrs
lre_save(n) = lre
enddo
! Broadcast block column
#ifdef WITH_MPI
call timer%start("mpi_communication")
#if REALCASE == 1
call MPI_Bcast(aux_bc, n_aux_bc, &
MPI_REAL_PRECISION, &
np_bc, mpi_comm_cols, mpierr)
#endif
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
! Insert what we got in aux_mat
n_aux_bc = 0
do n = 1, min(l_rows_np-nb*nblk,nblk)
nstor = nstor+1
lrs = lrs_save(n)
lre = lre_save(n)
if (lrs<=lre) then
nvals = lre-lrs+1
aux_mat(lrs:lre,nstor) = aux_bc(n_aux_bc+1:n_aux_bc+nvals)
n_aux_bc = n_aux_bc + nvals
endif
enddo
! If we got nblk_mult columns in aux_mat or this is the last block
! do the matrix multiplication
if (nstor==nblk_mult .or. nb*nblk+nblk >= l_rows_np) then
lrs = 1 ! 1st local row number for multiply
lre = l_rows ! last local row number for multiply
if (a_lower) lrs = local_index(gcol_min, my_prow, np_rows, nblk, +1)
if (a_upper) lre = local_index(gcol, my_prow, np_rows, nblk, -1)
lcs = 1 ! 1st local col number for multiply
lce = l_cols ! last local col number for multiply
if (c_upper) lcs = local_index(gcol_min, my_pcol, np_cols, nblk, +1)
if (c_lower) lce = MIN(local_index(gcol, my_pcol, np_cols, nblk, -1),l_cols)
if (lcs<=lce) then
allocate(tmp1(nstor,lcs:lce),tmp2(nstor,lcs:lce), stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when allocating tmp1 "//errorMessage
stop 1
endif
if (lrs<=lre) then
call timer%start("blas")
#if REALCASE == 1
call PRECISION_GEMM('T', 'N', &
#endif
#if COMPLEXCASE == 1
call PRECISION_GEMM('C', 'N', &
#endif
nstor, lce-lcs+1, lre-lrs+1, &
#if REALCASE == 1
CONST_1_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_1_0, &
#endif
aux_mat(lrs,1), ubound(aux_mat,dim=1), &
b(lrs,lcs), ldb, &
#if REALCASE == 1
CONST_0_0, &
#endif
#if COMPLEXCASE == 1
CONST_COMPLEX_PAIR_0_0, &
#endif
tmp1, nstor)
call timer%stop("blas")
else
tmp1 = 0
endif
! Sum up the results and send to processor row np
#ifdef WITH_MPI
call timer%start("mpi_communication")
call mpi_reduce(tmp1, tmp2, nstor*(lce-lcs+1), &
#if REALCASE == 1
MPI_REAL_PRECISION, &
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION, &
#endif
MPI_SUM, np, mpi_comm_rows, mpierr)
call timer%stop("mpi_communication")
! Put the result into C
if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp2(1:nstor,lcs:lce)
#else /* WITH_MPI */
! tmp2 = tmp1
! Put the result into C
if (my_prow==np) c(nr_done+1:nr_done+nstor,lcs:lce) = tmp1(1:nstor,lcs:lce)
#endif /* WITH_MPI */
deallocate(tmp1,tmp2, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when deallocating tmp1 "//errorMessage
stop 1
endif
endif
nr_done = nr_done+nstor
nstor=0
aux_mat(:,:)=0
endif
enddo
enddo
deallocate(aux_mat, aux_bc, lrs_save, lre_save, stat=istat, errmsg=errorMessage)
if (istat .ne. 0) then
print *,"elpa_mult_at_b_&
&MATH_DATATYPE&
&: error when deallocating aux_mat "//errorMessage
stop 1
endif
call elpa_uninit()
call timer%stop("elpa_mult_at_b_&
&MATH_DATATYPE&
......
......@@ -104,8 +104,8 @@
call timer%start("elpa_mult_at_b_&
&MATH_DATATYPE&
&_&
&PRECISION &
")
&PRECISION&
&_new")
success = .true.
......@@ -346,8 +346,8 @@
call timer%stop("elpa_mult_at_b_&
&MATH_DATATYPE&
&_&
&PRECISION &
")
&PRECISION&
&_new")
#undef REALCASE
#undef COMPLEXCASE
......
......@@ -70,6 +70,13 @@ module ELPA2_new
public :: elpa_solve_evp_real_2stage_double_new !< Driver routine for real double-precision 2-stage eigenvalue problem
public :: elpa_solve_evp_complex_2stage_double_new !< Driver routine for complex double-precision 2-stage eigenvalue problem
#ifdef WANT_SINGLE_PRECISION_REAL
public :: elpa_solve_evp_real_2stage_single_new !< Driver routine for real single-precision 2-stage eigenvalue problem
#endif
#ifdef WANT_SINGLE_PRECISION_COMPLEX
public :: elpa_solve_evp_complex_2stage_single_new !< Driver routine for complex single-precision 2-stage eigenvalue problem
#endif
contains
......
#include <elpa/elpa_constants.h>
#include "config-f90.h"
module elpa_type
use, intrinsic :: iso_c_binding
......@@ -47,6 +48,11 @@ module elpa_type
public :: elpa_init, elpa_initialized, elpa_uninit, elpa_create, elpa_t, c_int, c_double, c_float
interface elpa_create
module procedure elpa_create_generic
module procedure elpa_create_special
end interface
type :: elpa_t
private
type(c_ptr) :: options = C_NULL_PTR
......@@ -63,6 +69,10 @@ module elpa_type
generic, public :: get => elpa_get_integer
procedure, public :: get_communicators => get_communicators
procedure, public :: set_comm_rows
procedure, public :: set_comm_cols
generic, public :: solve => elpa_solve_real_double, &
elpa_solve_real_single, &
elpa_solve_complex_double, &
......@@ -162,7 +172,8 @@ module elpa_type
end subroutine
function elpa_create(na, nev, local_nrows, local_ncols, nblk, mpi_comm_parent, process_row, process_col, success) result(obj)
function elpa_create_generic(na, nev, local_nrows, local_ncols, nblk, mpi_comm_parent, &
process_row, process_col, success) result(obj)
use precision
use elpa_mpi
use elpa_utilities, only : error_unit
......@@ -208,6 +219,58 @@ module elpa_type
end function
function elpa_create_special(na, nev, local_nrows, local_ncols, nblk, success) result(obj)
use precision
use elpa_mpi
use elpa_utilities, only : error_unit
use elpa1_new, only : elpa_get_communicators_new
implicit none
integer(kind=ik), intent(in) :: na, nev, local_nrows, local_ncols, nblk
!integer, intent(in) :: mpi_comm_rows, mpi_comm_cols, process_row, process_col
type(elpa_t) :: obj
integer :: mpierr
integer :: success
! check whether init has ever been called
if (.not.(elpa_initialized())) then
write(error_unit, *) "elpa_create(): you must call elpa_init() once before creating instances of ELPA"
success = ELPA_ERROR
return
endif
obj%options = elpa_allocate_options()
obj%na = na
obj%nev = nev
obj%local_nrows = local_nrows
obj%local_ncols = local_ncols
obj%nblk = nblk
!obj%mpi_comm_rows = mpi_comm_rows
!obj%mpi_comm_cols = mpi_comm_rows
success = ELPA_OK
end function
subroutine set_comm_rows(self, mpi_comm_rows)
use iso_c_binding
implicit none
integer, intent(in) :: mpi_comm_rows
class(elpa_t) :: self
self%mpi_comm_rows = mpi_comm_rows
end subroutine
subroutine set_comm_cols(self, mpi_comm_cols)
use iso_c_binding
implicit none
integer, intent(in) :: mpi_comm_cols
class(elpa_t) :: self
self%mpi_comm_cols = mpi_comm_cols
end subroutine
subroutine elpa_set_integer(self, name, value, success)
use iso_c_binding
......@@ -475,14 +538,15 @@ module elpa_type
use elpa_utilities, only : error_unit
use iso_c_binding
use precision
implicit none
class(elpa_t) :: self
!#ifdef USE_ASSUMED_SIZE
! complex(kind=c_float_complex) :: a(self%local_nrows, *), q(self%local_nrows, *)
!#else
complex(kind=c_float_complex) :: a(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
complex(kind=ck4) :: a(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
!#endif
real(kind=c_float) :: ev(self%na)
real(kind=rk4) :: ev(self%na)
integer, optional :: success
integer(kind=c_int) :: success_internal
......@@ -546,14 +610,15 @@ module elpa_type
c, ldc, ldcCols, success)
use iso_c_binding
use elpa1_auxiliary_new
use precision
implicit none
class(elpa_t) :: self
character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else
real(kind=c_double) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
real(kind=rk8) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif
integer, optional :: success
logical :: success_l
......@@ -575,14 +640,15 @@ module elpa_type
c, ldc, ldcCols, success)
use iso_c_binding
use elpa1_auxiliary_new
use precision
implicit none
class(elpa_t) :: self
character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else
real(kind=c_float) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
real(kind=rk4) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif
integer, optional :: success
logical :: success_l
......@@ -605,14 +671,15 @@ module elpa_type
c, ldc, ldcCols, success)
use iso_c_binding
use elpa1_auxiliary_new
use precision
implicit none
class(elpa_t) :: self
character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE
! complex(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else
complex(kind=c_double_complex) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
complex(kind=ck8) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif
integer, optional :: success
logical :: success_l
......@@ -634,14 +701,15 @@ module elpa_type
c, ldc, ldcCols, success)
use iso_c_binding
use elpa1_auxiliary_new
use precision
implicit none
class(elpa_t) :: self
character*1 :: uplo_a, uplo_c
integer(kind=c_int), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
integer(kind=ik), intent(in) :: na, lda, ldaCols, ldb, ldbCols, ldc, ldcCols, ncb
!#ifdef USE_ASSUMED_SIZE
! real(kind=REAL_DATATYPE) :: a(lda,*), b(ldb,*), c(ldc,*)
!#else
complex(kind=c_float_complex) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
complex(kind=ck4) :: a(lda,ldaCols), b(ldb,ldbCols), c(ldc,ldcCols)
!#endif
integer, optional :: success
logical :: success_l
......
......@@ -222,7 +222,6 @@ program test_transpose_multiply
#ifdef WITH_MPI
call mpi_barrier(mpi_comm_world, mpierr) ! for correct timings only
#endif
success = elpa_mult_at_b_real_double("F","F", na, na, a, na_rows, na_cols, b, na_rows, &
na_cols, nblk, mpi_comm_rows, mpi_comm_cols, c, &
na_rows, na_cols)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment