Commit d2732032 authored by Pavel Kus's avatar Pavel Kus
Browse files

real/complex and single/double unification of elpa2/redist_band.F90

parent 6f68a719
...@@ -51,20 +51,7 @@ subroutine redist_band_& ...@@ -51,20 +51,7 @@ subroutine redist_band_&
&MATH_DATATYPE& &MATH_DATATYPE&
&_& &_&
&PRECISION & &PRECISION &
(obj, & (obj, a, a_dev, lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator, ab, useGPU)
#if REALCASE == 1
r_a, &
#endif
#if COMPLEXCASE == 1
c_a, &
#endif
a_dev, lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator, &
#if REALCASE == 1
r_ab, useGPU)
#endif
#if COMPLEXCASE == 1
c_ab, useGPU)
#endif
use elpa_abstract_impl use elpa_abstract_impl
use elpa2_workload use elpa2_workload
...@@ -78,30 +65,13 @@ subroutine redist_band_& ...@@ -78,30 +65,13 @@ subroutine redist_band_&
class(elpa_abstract_impl_t), intent(inout) :: obj class(elpa_abstract_impl_t), intent(inout) :: obj
logical, intent(in) :: useGPU logical, intent(in) :: useGPU
integer(kind=ik), intent(in) :: lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator integer(kind=ik), intent(in) :: lda, na, nblk, nbw, matrixCols, mpi_comm_rows, mpi_comm_cols, communicator
#if REALCASE == 1 MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(in) :: a(lda, matrixCols)
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(in) :: r_a(lda, matrixCols) MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(out) :: ab(:,:)
#endif
#if COMPLEXCASE == 1
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(in) :: c_a(lda, matrixCols)
#endif
#if REALCASE == 1
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(out) :: r_ab(:,:)
#endif
#if COMPLEXCASE == 1
MATH_DATATYPE(kind=C_DATATYPE_KIND), intent(out) :: c_ab(:,:)
#endif
integer(kind=ik), allocatable :: ncnt_s(:), nstart_s(:), ncnt_r(:), nstart_r(:), & integer(kind=ik), allocatable :: ncnt_s(:), nstart_s(:), ncnt_r(:), nstart_r(:), &
global_id(:,:), global_id_tmp(:,:), block_limits(:) global_id(:,:), global_id_tmp(:,:), block_limits(:)
#if REALCASE == 1 MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: sbuf(:,:,:), rbuf(:,:,:), buf(:,:)
MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: r_sbuf(:,:,:), r_rbuf(:,:,:), r_buf(:,:)
#endif
#if COMPLEXCASE == 1
MATH_DATATYPE(kind=C_DATATYPE_KIND), allocatable :: c_sbuf(:,:,:), c_rbuf(:,:,:), c_buf(:,:)
#endif
integer(kind=ik) :: i, j, my_pe, n_pes, my_prow, np_rows, my_pcol, np_cols, & integer(kind=ik) :: i, j, my_pe, n_pes, my_prow, np_rows, my_pcol, np_cols, &
nfact, np, npr, npc, mpierr, is, js nfact, np, npr, npc, mpierr, is, js
integer(kind=ik) :: nblocks_total, il, jl, l_rows, l_cols, n_off integer(kind=ik) :: nblocks_total, il, jl, l_rows, l_cols, n_off
...@@ -121,14 +91,7 @@ subroutine redist_band_& ...@@ -121,14 +91,7 @@ subroutine redist_band_&
if (useGPU) then if (useGPU) then
! copy a_dev to aMatrix ! copy a_dev to aMatrix
successCUDA = cuda_memcpy ( & successCUDA = cuda_memcpy (loc(a), int(a_dev,kind=c_intptr_t), int(lda*matrixCols* size_of_datatype, kind=c_intptr_t), &
#if REALCASE == 1
loc(r_a), &
#endif
#if COMPLEXCASE == 1
loc(c_a(1,1)), &
#endif
int(a_dev,kind=c_intptr_t), int(lda*matrixCols* size_of_datatype, kind=c_intptr_t), &
cudaMemcpyDeviceToHost) cudaMemcpyDeviceToHost)
if (.not.(successCUDA)) then if (.not.(successCUDA)) then
print *,"redist_band_& print *,"redist_band_&
...@@ -200,14 +163,8 @@ subroutine redist_band_& ...@@ -200,14 +163,8 @@ subroutine redist_band_&
! Allocate send buffer ! Allocate send buffer
#if REALCASE==1 allocate(sbuf(nblk,nblk,sum(ncnt_s)))
allocate(r_sbuf(nblk,nblk,sum(ncnt_s))) sbuf(:,:,:) = 0.
r_sbuf(:,:,:) = 0.
#endif
#if COMPLEXCASE==1
allocate(c_sbuf(nblk,nblk,sum(ncnt_s)))
c_sbuf(:,:,:) = 0.
#endif
! Determine start offsets in send buffer ! Determine start offsets in send buffer
...@@ -233,12 +190,7 @@ subroutine redist_band_& ...@@ -233,12 +190,7 @@ subroutine redist_band_&
jl = MIN(nblk,l_rows-js) jl = MIN(nblk,l_rows-js)
il = MIN(nblk,l_cols-is) il = MIN(nblk,l_cols-is)
#if REALCASE==1 sbuf(1:jl,1:il,nstart_s(np)) = a(js+1:js+jl,is+1:is+il)
r_sbuf(1:jl,1:il,nstart_s(np)) = r_a(js+1:js+jl,is+1:is+il)
#endif
#if COMPLEXCASE==1
c_sbuf(1:jl,1:il,nstart_s(np)) = c_a(js+1:js+jl,is+1:is+il)
#endif
endif endif
enddo enddo
endif endif
...@@ -258,12 +210,7 @@ subroutine redist_band_& ...@@ -258,12 +210,7 @@ subroutine redist_band_&
! Allocate receive buffer ! Allocate receive buffer
#if REALCASE==1 allocate(rbuf(nblk,nblk,sum(ncnt_r)))
allocate(r_rbuf(nblk,nblk,sum(ncnt_r)))
#endif
#if COMPLEXCASE==1
allocate(c_rbuf(nblk,nblk,sum(ncnt_r)))
#endif
! Set send counts/send offsets, receive counts/receive offsets ! Set send counts/send offsets, receive counts/receive offsets
! now actually in variables, not in blocks ! now actually in variables, not in blocks
...@@ -286,37 +233,12 @@ subroutine redist_band_& ...@@ -286,37 +233,12 @@ subroutine redist_band_&
#ifdef WITH_MPI #ifdef WITH_MPI
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
#if REALCASE==1 call MPI_Alltoallv(sbuf, ncnt_s, nstart_s, MPI_MATH_DATATYPE_PRECISION_EXPL, &
rbuf, ncnt_r, nstart_r, MPI_MATH_DATATYPE_PRECISION_EXPL, communicator, mpierr)
#ifdef DOUBLE_PRECISION_REAL
call MPI_Alltoallv(r_sbuf, ncnt_s, nstart_s, MPI_REAL8, r_rbuf, ncnt_r, nstart_r, MPI_REAL8, communicator, mpierr)
#else
call MPI_Alltoallv(r_sbuf, ncnt_s, nstart_s, MPI_REAL4, r_rbuf, ncnt_r, nstart_r, MPI_REAL4, communicator, mpierr)
#endif
#endif /* REALCASE==1 */
#if COMPLEXCASE==1
#ifdef DOUBLE_PRECISION_COMPLEX
call MPI_Alltoallv(c_sbuf, ncnt_s, nstart_s, MPI_COMPLEX16, c_rbuf, ncnt_r, nstart_r, MPI_COMPLEX16, communicator, mpierr)
#else
call MPI_Alltoallv(c_sbuf, ncnt_s, nstart_s, MPI_COMPLEX, c_rbuf, ncnt_r, nstart_r, MPI_COMPLEX, communicator, mpierr)
#endif
#endif /* COMPLEXCASE==1 */
call obj%timer%stop("mpi_communication") call obj%timer%stop("mpi_communication")
#else /* WITH_MPI */ #else /* WITH_MPI */
rbuf = sbuf
#if REALCASE==1
r_rbuf = r_sbuf
#endif
#if COMPLEXCASE==1
c_rbuf = c_sbuf
#endif
#endif /* WITH_MPI */ #endif /* WITH_MPI */
! set band from receive buffer ! set band from receive buffer
...@@ -328,12 +250,7 @@ subroutine redist_band_& ...@@ -328,12 +250,7 @@ subroutine redist_band_&
nstart_r(i) = nstart_r(i-1) + ncnt_r(i-1) nstart_r(i) = nstart_r(i-1) + ncnt_r(i-1)
enddo enddo
#if REALCASE==1 allocate(buf((nfact+1)*nblk,nblk))
allocate(r_buf((nfact+1)*nblk,nblk))
#endif
#if COMPLEXCASE==1
allocate(c_buf((nfact+1)*nblk,nblk))
#endif
! n_off: Offset of ab within band ! n_off: Offset of ab within band
n_off = block_limits(my_pe)*nbw n_off = block_limits(my_pe)*nbw
...@@ -345,19 +262,14 @@ subroutine redist_band_& ...@@ -345,19 +262,14 @@ subroutine redist_band_&
np = global_id(npr,npc) np = global_id(npr,npc)
nstart_r(np) = nstart_r(np) + 1 nstart_r(np) = nstart_r(np) + 1
#if REALCASE==1 #if REALCASE==1
r_buf(i*nblk+1:i*nblk+nblk,:) = transpose(r_rbuf(:,:,nstart_r(np))) buf(i*nblk+1:i*nblk+nblk,:) = transpose(rbuf(:,:,nstart_r(np)))
#endif #endif
#if COMPLEXCASE==1 #if COMPLEXCASE==1
c_buf(i*nblk+1:i*nblk+nblk,:) = conjg(transpose(c_rbuf(:,:,nstart_r(np)))) buf(i*nblk+1:i*nblk+nblk,:) = conjg(transpose(rbuf(:,:,nstart_r(np))))
#endif #endif
enddo enddo
do i=1,MIN(nblk,na-j*nblk) do i=1,MIN(nblk,na-j*nblk)
#if REALCASE==1 ab(1:nbw+1,i+j*nblk-n_off) = buf(i:i+nbw,i)
r_ab(1:nbw+1,i+j*nblk-n_off) = r_buf(i:i+nbw,i)
#endif
#if COMPLEXCASE==1
c_ab(1:nbw+1,i+j*nblk-n_off) = c_buf(i:i+nbw,i)
#endif
enddo enddo
enddo enddo
...@@ -366,12 +278,7 @@ subroutine redist_band_& ...@@ -366,12 +278,7 @@ subroutine redist_band_&
deallocate(global_id) deallocate(global_id)
deallocate(block_limits) deallocate(block_limits)
#if REALCASE==1 deallocate(sbuf, rbuf, buf)
deallocate(r_sbuf, r_rbuf, r_buf)
#endif
#if COMPLEXCASE==1
deallocate(c_sbuf, c_rbuf, c_buf)
#endif
call obj%timer%stop("redist_band_& call obj%timer%stop("redist_band_&
&MATH_DATATYPE& &MATH_DATATYPE&
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment