Commit 1d452c16 authored by Pavel Kus

elpa2_tridiag_band_complex single/double unification

parent afc5a14a
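Pattern applied throughout this commit: duplicated #ifdef DOUBLE_PRECISION_COMPLEX branches are collapsed into a single code path that uses precision-neutral tokens (PRECISION_HEMV, PRECISION_GEMV, CONST_COMPLEX_0_0, MPI_COMPLEX_EXPLICIT_PRECISION, ...), which the precision macro headers map to the concrete double- or single-precision names. A minimal before/after sketch, modelled on the HEMV call in this diff (indentation is illustrative):

      ! before: two copies of every call, selected by the preprocessor
#ifdef DOUBLE_PRECISION_COMPLEX
      call ZHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk8,0.0_rk8), hd, 1)
#else
      call CHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk4,0.0_rk4), hd, 1)
#endif
      ! after: one copy; the tokens are resolved by the per-precision macro header
      call PRECISION_HEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, CONST_COMPLEX_PAIR_0_0, hd, 1)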
......@@ -62,6 +62,8 @@ blas_tokens = [
"PRECISION_SYRK",
"PRECISION_SYMV",
"PRECISION_SYMM",
"PRECISION_HEMV",
"PRECISION_HER2",
"PRECISION_SYR2",
"PRECISION_SYR2K",
"PRECISION_GEQRF",
......@@ -88,7 +90,12 @@ explicit_tokens_complex = [
("PRECISION_REAL", "DREAL", "REAL"),
("CONST_REAL_0_0", "0.0_rk8", "0.0_rk4"),
("CONST_REAL_1_0", "1.0_rk8", "1.0_rk4"),
("CONST_REAL_0_5", "0.5_rk8", "0.5_rk4"),
("CONST_COMPLEX_PAIR_0_0", "(0.0_rk8,0.0_rk8)", "(0.0_rk4,0.0_rk4)"),
("CONST_COMPLEX_PAIR_1_0", "(1.0_rk8,0.0_rk8)", "(1.0_rk4,0.0_rk4)"),
("CONST_COMPLEX_PAIR_NEGATIVE_1_0", "(-1.0_rk8,0.0_rk8)", "(-1.0_rk4,0.0_rk4)"),
("CONST_COMPLEX_0_0", "0.0_ck8", "0.0_ck4"),
("CONST_COMPLEX_1_0", "1.0_ck8", "1.0_ck4"),
("size_of_PRECISION_complex", "size_of_double_complex_datatype", "size_of_single_complex_datatype"),
]
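Each tuple above pairs a precision-neutral token with its double-precision expansion (second element) and single-precision expansion (third element); the macro header then carries one #define block per precision. A sketch of what that looks like for the new complex constants, assuming the single-precision block mirrors the double-precision one shown near the end of this diff (the single-precision block itself is not visible here):

! double-precision branch (visible in the #define hunk below):
#define CONST_COMPLEX_PAIR_0_0 (0.0_rk8,0.0_rk8)
#define CONST_COMPLEX_0_0 0.0_ck8
! single-precision branch (inferred from the third tuple elements):
#define CONST_COMPLEX_PAIR_0_0 (0.0_rk4,0.0_rk4)
#define CONST_COMPLEX_0_0 0.0_ck4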
......
#ifdef DOUBLE_PRECISION_COMPLEX
subroutine tridiag_band_complex_double(na, nb, nblk, a, lda, d, e, matrixCols, hh_trans_complex, &
mpi_comm_rows, mpi_comm_cols, mpi_comm)
#else
subroutine tridiag_band_complex_single(na, nb, nblk, a, lda, d, e, matrixCols, hh_trans_complex, &
subroutine tridiag_band_complex_PRECISION(na, nb, nblk, a, lda, d, e, matrixCols, hh_trans_complex, &
mpi_comm_rows, mpi_comm_cols, mpi_comm)
#endif
!-------------------------------------------------------------------------------
! tridiag_band_complex:
! Reduces a complex Hermitian band matrix to tridiagonal form
......@@ -88,11 +81,7 @@
! ! dummies for calling redist_band
! real*8 :: r_a(1,1), r_ab(1,1)
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%start("tridiag_band_complex_double")
#else
call timer%start("tridiag_band_complex_single")
#endif
call timer%start("tridiag_band_complex_PRECISION")
call timer%start("mpi_communication")
call mpi_comm_rank(mpi_comm,my_pe,mpierr)
call mpi_comm_size(mpi_comm,n_pes,mpierr)
......@@ -190,11 +179,7 @@
n_off = block_limits(my_pe)*nb
! Redistribute band in a to ab
#ifdef DOUBLE_PRECISION_COMPLEX
call redist_band_complex_double(a, lda, na, nblk, nb, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm, ab)
#else
call redist_band_complex_single(a, lda, na, nblk, nb, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm, ab)
#endif
call redist_band_complex_PRECISION(a, lda, na, nblk, nb, matrixCols, mpi_comm_rows, mpi_comm_cols, mpi_comm, ab)
! Calculate the workload for each sweep in the back transformation
! and the space requirements to hold the HH vectors
......@@ -254,13 +239,8 @@
num_chunks = num_chunks+1
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_irecv(hh_trans_complex(1,num_hh_vecs+1), nb*local_size, MPI_COMPLEX16, nt, &
call mpi_irecv(hh_trans_complex(1,num_hh_vecs+1), nb*local_size, MPI_COMPLEX_EXPLICIT_PRECISION, nt, &
10+n-block_limits(nt), mpi_comm, ireq_hhr(num_chunks), mpierr)
#else
call mpi_irecv(hh_trans_complex(1,num_hh_vecs+1), nb*local_size, MPI_COMPLEX8, nt, &
10+n-block_limits(nt), mpi_comm, ireq_hhr(num_chunks), mpierr)
#endif
call timer%stop("mpi_communication")
#else /* WITH_MPI */
! careful: non-blocking recv data copy must be done at wait or send
......@@ -291,13 +271,8 @@
stop
endif
#ifdef DOUBLE_PRECISION_COMPLEX
hh_gath(:,:,:) = 0._ck8
hh_send(:,:,:) = 0._ck8
#else
hh_gath(:,:,:) = 0._ck4
hh_send(:,:,:) = 0._ck4
#endif
hh_gath(:,:,:) = CONST_COMPLEX_0_0
hh_send(:,:,:) = CONST_COMPLEX_0_0
! Some counters
......@@ -354,14 +329,8 @@
print *,"tridiag_band_complex: error when allocating hv_t, tau_t "//errorMessage
stop
endif
#ifdef DOUBLE_PRECISION_COMPLEX
hv_t = 0._ck8
tau_t = 0._ck8
#else
hv_t = 0._ck4
tau_t = 0._ck4
#endif
hv_t = CONST_COMPLEX_0_0
tau_t = CONST_COMPLEX_0_0
#endif
! ---------------------------------------------------------------------------
......@@ -375,11 +344,7 @@
ab_s(1:nb+1) = ab(1:nb+1,na_s-n_off)
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(ab_s, nb+1, MPI_COMPLEX16, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#else
call mpi_isend(ab_s, nb+1, MPI_COMPLEX8, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#endif
call mpi_isend(ab_s, nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
endif
......@@ -395,13 +360,8 @@
#endif
if (my_pe==0) then
n = MIN(na-na_s,nb) ! number of rows to be reduced
#ifdef DOUBLE_PRECISION_COMPLEX
hv(:) = 0._ck8
tau = 0._ck8
#else
hv(:) = 0._ck4
tau = 0._ck4
#endif
hv(:) = CONST_COMPLEX_0_0
tau = CONST_COMPLEX_0_0
! Transform first column of remaining matrix
! Unlike the real case, the last step (istep=na-1) is needed here for making
! the last subdiagonal element a real number
......@@ -411,28 +371,16 @@
vnorm2 = sum(real(ab(3:n+1,na_s-n_off),kind=rk4)**2+aimag(ab(3:n+1,na_s-n_off))**2)
#endif
if (n<2) vnorm2 = 0. ! Safety only
#ifdef DOUBLE_PRECISION_COMPLEX
call hh_transform_complex_double(ab(2,na_s-n_off),vnorm2,hf,tau)
#else
call hh_transform_complex_single(ab(2,na_s-n_off),vnorm2,hf,tau)
#endif
call hh_transform_complex_PRECISION(ab(2,na_s-n_off),vnorm2,hf,tau)
#ifdef DOUBLE_PRECISION_COMPLEX
hv(1) = 1._ck8
#else
hv(1) = 1._ck4
#endif
hv(1) = CONST_COMPLEX_1_0
hv(2:n) = ab(3:n+1,na_s-n_off)*hf
d(istep) = ab(1,na_s-n_off)
e(istep) = ab(2,na_s-n_off)
if (istep == na-1) then
d(na) = ab(1,na_s+1-n_off)
#ifdef DOUBLE__PRECISION_COMPLEX
e(na) = 0._rk8
#else
e(na) = 0._rk4
#endif
e(na) = CONST_REAL_0_0
endif
else
if (na>na_s) then
......@@ -441,11 +389,7 @@
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_recv(hv, nb, MPI_COMPLEX16, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#else
call mpi_recv(hv, nb, MPI_COMPLEX8, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#endif
call mpi_recv(hv, nb, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
call timer%stop("mpi_communication")
#else /* WITH_MPI */
hv(1:nb) = hv_s(1:nb)
......@@ -455,11 +399,7 @@
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_recv(hv, nb, MPI_COMPLEX16, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#else
call mpi_recv(hv, nb, MPI_COMPLEX8, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#endif
call mpi_recv(hv, nb, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe-1, 2, mpi_comm, MPI_STATUS_IGNORE, mpierr)
call timer%stop("mpi_communication")
#else /* WITH_MPI */
hv(1:nb) = hv_s(1:nb)
......@@ -467,11 +407,7 @@
#endif /* WITH_OPENMP */
tau = hv(1)
#ifdef DOUBLE_PRECISION_COMPLEX
hv(1) = 1._ck8
#else
hv(1) = 1._ck4
#endif
hv(1) = CONST_COMPLEX_1_0
endif
endif
......@@ -510,11 +446,7 @@
! is completed by the next thread.
! After the first iteration it is also the place to exchange the last row
! with MPI calls
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%start("OpenMP parallel_double")
#else
call timer%start("OpenMP parallel_single")
#endif
!$omp parallel do private(my_thread, my_block_s, my_block_e, iblk, ns, ne, hv, tau, &
!$omp& nc, nr, hs, hd, vnorm2, hf, x, h, i), schedule(static,1), num_threads(max_threads)
......@@ -550,35 +482,18 @@
! Note that nr>=0 implies that diagonal block is full (nc==nb)!
! Transform diagonal block
#ifdef DOUBLE_PRECISION_COMPLEX
call ZHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk8, 0.0_rk8), hd, 1)
#else
call CHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk4, 0.0_rk4), hd, 1)
#endif
call PRECISION_HEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, CONST_COMPLEX_PAIR_0_0, hd, 1)
x = dot_product(hv(1:nc),hd(1:nc))*conjg(tau)
hd(1:nc) = hd(1:nc) - 0.5*x*hv(1:nc)
#ifdef DOUBLE_PRECISION_COMPLEX
call ZHER2('L', nc, (-1.0_rk8, 0.0_rk8), hd, 1, hv, 1, ab(1,ns), 2*nb-1)
#else
call CHER2('L', nc, (-1.0_rk4, 0.0_rk4), hd, 1, hv, 1, ab(1,ns), 2*nb-1)
#endif
call PRECISION_HER2('L', nc, CONST_COMPLEX_PAIR_NEGATIVE_1_0, hd, 1, hv, 1, ab(1,ns), 2*nb-1)
#ifdef DOUBLE_PRECISION_COMPLEX
hv_t(:,my_thread) = 0._ck8
tau_t(my_thread) = 0._ck8
#else
hv_t(:,my_thread) = 0._ck4
tau_t(my_thread) = 0._ck4
#endif
hv_t(:,my_thread) = CONST_COMPLEX_0_0
tau_t(my_thread) = CONST_COMPLEX_0_0
if (nr<=0) cycle ! No subdiagonal block present any more
! Transform subdiagonal block
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, (0.0_rk8,0.0_rk8), hs, 1)
#else
call CGEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, (0.0_rk4,0.0_rk4), hs, 1)
#endif
call PRECISION_GEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, CONST_COMPLEX_PAIR_0_0, hs, 1)
if (nr>1) then
......@@ -594,29 +509,16 @@
vnorm2 = sum(real(ab(nb+2:nb+nr,ns))**2+aimag(ab(nb+2:nb+nr,ns))**2)
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call hh_transform_complex(ab(nb+1,ns),vnorm2,hf,tau_t(my_thread))
#else
call hh_transform_complex_single(ab(nb+1,ns),vnorm2,hf,tau_t(my_thread))
#endif
call hh_transform_complex_PRECISION(ab(nb+1,ns),vnorm2,hf,tau_t(my_thread))
#ifdef DOUBLE_PRECISION_COMPLEX
hv_t(1 ,my_thread) = 1._ck8
hv_t(2:nr,my_thread) = ab(nb+2:nb+nr,ns)*hf
ab(nb+2:,ns) = 0._ck8
#else
hv_t(1 ,my_thread) = 1._ck4
hv_t(1 ,my_thread) = CONST_COMPLEX_1_0
hv_t(2:nr,my_thread) = ab(nb+2:nb+nr,ns)*hf
ab(nb+2:,ns) = 0._ck4
#endif
ab(nb+2:,ns) = CONST_COMPLEX_0_0
! update subdiagonal block for old and new Householder transformation
! This way we can use a nonsymmetric rank 2 update which is (hopefully) faster
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('C', nr, nb-1, tau_t(my_thread), ab(nb,ns+1), 2*nb-1, hv_t(1,my_thread), 1, (0.0_rk8,0.0_rk8), h(2), 1)
#else
call CGEMV('C', nr, nb-1, tau_t(my_thread), ab(nb,ns+1), 2*nb-1, hv_t(1,my_thread), 1, (0.0_rk4,0.0_rk4), h(2), 1)
#endif
call PRECISION_GEMV('C', nr, nb-1, tau_t(my_thread), ab(nb,ns+1), 2*nb-1, hv_t(1,my_thread), &
1, CONST_COMPLEX_PAIR_0_0, h(2), 1)
x = dot_product(hs(1:nr),hv_t(1:nr,my_thread))*tau_t(my_thread)
h(2:nb) = h(2:nb) - x*hv(2:nb)
! Unfortunately there is no BLAS routine like DSYR2 for a nonsymmetric rank 2 update ("DGER2")
......@@ -633,23 +535,14 @@
ab(2+nb-i,i+ns-1) = ab(2+nb-i,i+ns-1) - hs(1)*conjg(hv(i))
enddo
! For safety: there is one remaining dummy transformation (but tau is 0 anyways)
#ifdef DOUBLE_PRECISION_COMPLEX
hv_t(1,my_thread) = 1._ck8
#else
hv_t(1,my_thread) = 1._ck4
#endif
hv_t(1,my_thread) = CONST_COMPLEX_1_0
endif
enddo
enddo ! my_thread
!$omp end parallel do
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%stop("OpenMP parallel_double")
#else
call timer%stop("OpenMP parallel_single")
#endif
if (iter==1) then
! We are at the end of the first block
......@@ -665,14 +558,8 @@
ab_s(1:nb+1) = ab(1:nb+1,na_s-n_off)
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(ab_s, nb+1, MPI_COMPLEX16, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#else
call mpi_isend(ab_s, nb+1, MPI_COMPLEX8, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#endif
call mpi_isend(ab_s, nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
endif
......@@ -682,12 +569,7 @@
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_recv(ab(1,ne-n_off), nb+1, MPI_COMPLEX16, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#else
call mpi_recv(ab(1,ne-n_off), nb+1, MPI_COMPLEX8, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#endif
call mpi_recv(ab(1,ne-n_off), nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
call timer%stop("mpi_communication")
#else /* WITH_MPI */
......@@ -705,21 +587,14 @@
#ifdef WITH_MPI
call timer%start("mpi_communication")
call mpi_wait(ireq_hv, MPI_STATUS_IGNORE,mpierr)
call timer%stop("mpi_communication")
#endif
hv_s(1) = tau_t(max_threads)
hv_s(2:) = hv_t(2:,max_threads)
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(hv_s, nb, MPI_COMPLEX16, my_pe+1, 2, mpi_comm, ireq_hv, mpierr)
#else
call mpi_isend(hv_s, nb, MPI_COMPLEX8, my_pe+1, 2, mpi_comm, ireq_hv, mpierr)
#endif
call mpi_isend(hv_s, nb, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe+1, 2, mpi_comm, ireq_hv, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
......@@ -778,15 +653,9 @@
! Send to destination
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), MPI_COMPLEX16, &
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), MPI_COMPLEX_EXPLICIT_PRECISION, &
global_id(hh_dst(iblk), mod(iblk+block_limits(my_pe)-1,np_cols)), &
10+iblk, mpi_comm, ireq_hhs(iblk), mpierr)
#else
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), MPI_COMPLEX8, &
global_id(hh_dst(iblk), mod(iblk+block_limits(my_pe)-1,np_cols)), &
10+iblk, mpi_comm, ireq_hhs(iblk), mpierr)
#endif
call timer%stop("mpi_communication")
#else /* WITH_MPI */
startAddr = startAddr - hh_cnt(iblk)
......@@ -817,22 +686,21 @@
! First do the matrix multiplications without last column ...
! Diagonal block, the contribution of the last element is added below
#ifdef DOUBLE_PRECISION_COMPLEX
ab(1,ne) = 0._ck8
call ZHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1,(0.0_rk8,0.0_rk8),hd,1)
ab(1,ne) = CONST_COMPLEX_0_0
call PRECISION_HEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1,CONST_COMPLEX_PAIR_0_0,hd,1)
! Subdiagonal block
if (nr>0) call ZGEMV('N', nr, nb-1, tau, ab(nb+1,ns), 2*nb-1, hv, 1,(0.0_rk8,0.0_rk8),hs,1)
if (nr>0) call PRECISION_GEMV('N', nr, nb-1, tau, ab(nb+1,ns), 2*nb-1, hv, 1,CONST_COMPLEX_PAIR_0_0,hs,1)
! ... then request last column ...
! ... then request last column ...
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef WITH_OPENMP
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX16, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#else
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX16, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#endif
call timer%stop("mpi_communication")
#else /* WITH_MPI */
......@@ -845,42 +713,10 @@
else
! Normal matrix multiply
call ZHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk8,0.0_rk8), hd, 1)
if (nr>0) call ZGEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, (0.0_rk8,0.0_rk8), hs, 1)
call PRECISION_HEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, CONST_COMPLEX_PAIR_0_0, hd, 1)
if (nr>0) call PRECISION_GEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, CONST_COMPLEX_PAIR_0_0, hs, 1)
endif
#else /* DOUBLE_PRECISION_COMPLEX */
ab(1,ne) = 0._ck4
call CHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1,(0.0_rk4,0.0_rk4),hd,1)
! Subdiagonal block
if (nr>0) call CGEMV('N', nr, nb-1, tau, ab(nb+1,ns), 2*nb-1, hv, 1,(0.0_rk4,0.0_rk4),hs,1)
! ... then request last column ...
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef WITH_OPENMP
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX8, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#else
call mpi_recv(ab(1,ne), nb+1, MPI_COMPLEX8, my_pe+1, 1, mpi_comm, MPI_STATUS_IGNORE, mpierr)
#endif
call timer%stop("mpi_communication")
#else /* WITH_MPI */
ab(1:nb+1,ne) = ab_s(1:nb+1)
#endif /* WITH_MPI */
! ... and complete the result
hs(1:nr) = hs(1:nr) + ab(2:nr+1,ne)*tau*hv(nb)
hd(nb) = hd(nb) + ab(1,ne)*hv(nb)*tau
else
! Normal matrix multiply
call CHEMV('L', nc, tau, ab(1,ns), 2*nb-1, hv, 1, (0.0_rk4,0.0_rk4), hd, 1)
if (nr>0) call CGEMV('N', nr, nb, tau, ab(nb+1,ns), 2*nb-1, hv, 1, (0.0_rk4,0.0_rk4), hs, 1)
endif
#endif /* DOUBLE_PRECISION_COMPLEX */
! Calculate first column of subdiagonal block and calculate new
! Householder transformation for this column
......@@ -906,21 +742,11 @@
vnorm2 = sum(real(ab(nb+2:nb+nr,ns),kind=rk4)**2+aimag(ab(nb+2:nb+nr,ns))**2)
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call hh_transform_complex_double(ab(nb+1,ns),vnorm2,hf,tau_new)
#else
call hh_transform_complex_single(ab(nb+1,ns),vnorm2,hf,tau_new)
#endif
call hh_transform_complex_PRECISION(ab(nb+1,ns),vnorm2,hf,tau_new)
#ifdef DOUBLE_PRECISION_COMPLEX
hv_new(1) = 1._ck8
hv_new(2:nr) = ab(nb+2:nb+nr,ns)*hf
ab(nb+2:,ns) = 0._ck8
#else
hv_new(1) = 1._ck4
hv_new(1) = CONST_COMPLEX_1_0
hv_new(2:nr) = ab(nb+2:nb+nr,ns)*hf
ab(nb+2:,ns) = 0._ck4
#endif
ab(nb+2:,ns) = CONST_COMPLEX_0_0
endif
! ... and send it away immediately if this is the last block
......@@ -940,14 +766,8 @@
hv_s(2:) = hv_new(2:)
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(hv_s, nb, MPI_COMPLEX16, my_pe+1, 2 ,mpi_comm, ireq_hv, mpierr)
#else
call mpi_isend(hv_s, nb, MPI_COMPLEX8, my_pe+1, 2 ,mpi_comm, ireq_hv, mpierr)
#endif
call mpi_isend(hv_s, nb, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe+1, 2 ,mpi_comm, ireq_hv, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
endif
......@@ -956,11 +776,7 @@
! Transform diagonal block
x = dot_product(hv(1:nc),hd(1:nc))*conjg(tau)
#ifdef DOUBLE_PRECISION_COMPLEX
hd(1:nc) = hd(1:nc) - 0.5_rk8*x*hv(1:nc)
#else
hd(1:nc) = hd(1:nc) - 0.5_rk4*x*hv(1:nc)
#endif
hd(1:nc) = hd(1:nc) - CONST_REAL_0_5*x*hv(1:nc)
!#ifdef WITH_GPU_VERSION
! istat = cuda_memcpy2d((ab_dev + (ns-1)*2*nb*size_of_complex_datatype), 2*nb*size_of_complex_datatype,loc(a(1,ns)), 2*nb*size_of_complex_datatype, 2*size_of_complex_datatype , &
......@@ -1000,11 +816,7 @@
ab_s(1:nb+1) = ab(1:nb+1,ns)
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(ab_s, nb+1, MPI_COMPLEX16, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#else
call mpi_isend(ab_s, nb+1, MPI_COMPLEX8, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
#endif
call mpi_isend(ab_s, nb+1, MPI_COMPLEX_EXPLICIT_PRECISION, my_pe-1, 1, mpi_comm, ireq_ab, mpierr)
call timer%stop("mpi_communication")
#endif /* WITH_MPI */
......@@ -1012,26 +824,18 @@
if (nc>1) then
!#ifdef WITH_GPU_VERSION
! call cublas_ZHER2( 'L',nc -1,(-1.d0,0.d0), hd_dev + 1*16, 1, hv_dev +1*16, 1 , ab_dev + (ns*2*nb )*16, 2*nb-1)
! call cublas_PRECISION_HER2( 'L',nc -1,(-1.d0,0.d0), hd_dev + 1*16, 1, hv_dev +1*16, 1 , ab_dev + (ns*2*nb )*16, 2*nb-1)
!#else
#ifdef DOUBLE_PRECISION_COMPLEX
call ZHER2('L', nc-1, (-1.0_rk8,0.0_rk8), hd(2), 1, hv(2), 1, ab(1,ns+1), 2*nb-1)
#else
call CHER2('L', nc-1, (-1.0_rk4,0.0_rk4), hd(2), 1, hv(2), 1, ab(1,ns+1), 2*nb-1)
#endif
call PRECISION_HER2('L', nc-1, CONST_COMPLEX_PAIR_NEGATIVE_1_0, hd(2), 1, hv(2), 1, ab(1,ns+1), 2*nb-1)
!#endif
endif
else
! No need to send, just a rank-2 update
!#ifdef WITH_GPU_VERSION
! call cublas_ZHER2( 'L',nc ,(-1.d0,0.d0), hd_dev, 1, hv_dev, 1 , ab_dev + ((ns-1)*2*nb )*16, 2*nb-1)
! call cublas_PRECISION_HER2( 'L',nc ,(-1.d0,0.d0), hd_dev, 1, hv_dev, 1 , ab_dev + ((ns-1)*2*nb )*16, 2*nb-1)
!#else
#ifdef DOUBLE_PRECISION_COMPLEX
call ZHER2('L', nc, (-1.0_rk8,0.0_rk8), hd, 1, hv, 1, ab(1,ns), 2*nb-1)
#else
call CHER2('L', nc, (-1.0_rk4,0.0_rk4), hd, 1, hv, 1, ab(1,ns), 2*nb-1)
#endif
call PRECISION_HER2('L', nc, CONST_COMPLEX_PAIR_NEGATIVE_1_0, hd, 1, hv, 1, ab(1,ns), 2*nb-1)
!#endif
endif
......@@ -1056,7 +860,7 @@
! istat = cuda_memcpy(h_dev,loc(h),nb*size_of_complex_datatype,cudaMemcpyHostToDevice)
! if (istat .ne. 0) print *,"cuda memcpy failed h_dev", istat
!
! call cublas_ZGEMV('C',nr,nb-1,tau_new,ab_dev + (nb-1 + ns *2*nb)*16,2*nb-1,hv_new_dev,1,(0.d0,0.d0),h_dev + 1* 16,1)
! call cublas_PRECISION_GEMV('C',nr,nb-1,tau_new,ab_dev + (nb-1 + ns *2*nb)*16,2*nb-1,hv_new_dev,1,(0.d0,0.d0),h_dev + 1* 16,1)
!
! istat = cuda_memcpy(tau_new_dev,loc(tau_new),1*size_of_complex_datatype,cudaMemcpyHostToDevice)
! if (istat .ne. 0) print *,"cuda memcpy failed tau_new_dev", istat
......@@ -1070,11 +874,8 @@
! istat =cuda_memcpy(loc(h),h_dev,nb*size_of_complex_datatype,cudaMemcpyDeviceToHost)
! if (istat .ne. 0) print *, " cuda memcpy failed h ", istat
!#else
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('C', nr, nb-1, tau_new, ab(nb,ns+1), 2*nb-1, hv_new, 1, (0.0_rk8, 0.0_rk8), h(2), 1)
#else
call CGEMV('C', nr, nb-1, tau_new, ab(nb,ns+1), 2*nb-1, hv_new, 1, (0.0_rk4, 0.0_rk4), h(2), 1)
#endif
call PRECISION_GEMV('C', nr, nb-1, tau_new, ab(nb,ns+1), 2*nb-1, hv_new, 1, CONST_COMPLEX_PAIR_0_0, h(2), 1)
x = dot_product(hs(1:nr),hv_new(1:nr))*tau_new
h(2:nb) = h(2:nb) - x*hv(2:nb)
! Unfortunately there is no BLAS routine like DSYR2 for a nonsymmetric rank 2 update
......@@ -1127,15 +928,9 @@
! Send to destination
#ifdef WITH_MPI
call timer%start("mpi_communication")
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), mpi_complex16, &
global_id(hh_dst(iblk), mod(iblk+block_limits(my_pe)-1, np_cols)), &
10+iblk, mpi_comm, ireq_hhs(iblk), mpierr)
#else
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), mpi_complex8, &
call mpi_isend(hh_send(1,1,iblk), nb*hh_cnt(iblk), MPI_COMPLEX_EXPLICIT_PRECISION, &
global_id(hh_dst(iblk), mod(iblk+block_limits(my_pe)-1, np_cols)), &
10+iblk, mpi_comm, ireq_hhs(iblk), mpierr)
#endif
call timer%stop("mpi_communication")
#else /* WITH_MPI */
startAddr = startAddr - hh_cnt(iblk)
......@@ -1285,11 +1080,7 @@
! endif
!
!#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%stop("tridiag_band_complex_double")
#else
call timer%stop("tridiag_band_complex_single")
#endif
call timer%stop("tridiag_band_complex_PRECISION")
!#ifdef WITH_GPU_VERSION
! contains
......@@ -1323,10 +1114,4 @@
! call launch_double_hh_transform_2(ab_dev,hd_dev,hv_dev,nc , ns, nb)
! end subroutine
!#endif
#ifdef DOUBLE_PRECISION_COMPLEX
end subroutine tridiag_band_complex_double ! has to be checked for GPU
#else
end subroutine tridiag_band_complex_single ! has to be checked for GPU
#endif
end subroutine tridiag_band_complex_PRECISION ! has to be checked for GPU
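The unified body is instantiated once per precision: a wrapper presumably defines the precision guard, includes the macro header (which maps every PRECISION_* / CONST_* token and the _PRECISION name suffix, as the defines below for elpa_transpose_vectors_complex_PRECISION illustrate), and then includes this template. The file names in this sketch are assumptions, not taken from this commit:

! double-precision instantiation (file names are illustrative only)
#define DOUBLE_PRECISION_COMPLEX 1
#include "precision_macros_complex.h"
#include "elpa2_tridiag_band_complex_template.X90"
#undef DOUBLE_PRECISION_COMPLEX

! single-precision instantiation: without the guard, the same header
! presumably maps the tokens to CHEMV, rk4, ck4, ...
#include "precision_macros_complex.h"
#include "elpa2_tridiag_band_complex_template.X90"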
......@@ -54,6 +54,8 @@
#undef PRECISION_SYRK
#undef PRECISION_SYMV
#undef PRECISION_SYMM
#undef PRECISION_HEMV
#undef PRECISION_HER2
#undef PRECISION_SYR2
#undef PRECISION_SYR2K
#undef PRECISION_GEQRF
......@@ -77,7 +79,12 @@
#undef PRECISION_REAL
#undef CONST_REAL_0_0
#undef CONST_REAL_1_0
#undef CONST_REAL_0_5
#undef CONST_COMPLEX_PAIR_0_0
#undef CONST_COMPLEX_PAIR_1_0
#undef CONST_COMPLEX_PAIR_NEGATIVE_1_0
#undef CONST_COMPLEX_0_0
#undef CONST_COMPLEX_1_0
#undef size_of_PRECISION_complex
#define elpa_transpose_vectors_complex_PRECISION elpa_transpose_vectors_complex_double
#define elpa_reduce_add_vectors_complex_PRECISION elpa_reduce_add_vectors_complex_double
......@@ -134,6 +141,8 @@
#define PRECISION_SYRK ZSYRK
#define PRECISION_SYMV ZSYMV
#define PRECISION_SYMM ZSYMM
#define PRECISION_HEMV ZHEMV
#define PRECISION_HER2 ZHER2
#define PRECISION_SYR2 ZSYR2
#define PRECISION_SYR2K ZSYR2K
#define PRECISION_GEQRF ZGEQRF
......@@ -157,7 +166,12 @@
#define PRECISION_REAL DREAL
#define CONST_REAL_0_0 0.0_rk8
#define CONST_REAL_1_0 1.0_rk8
#define CONST_REAL_0_5 0.5_rk8
#define CONST_COMPLEX_PAIR_0_0 (0.0_rk8,0.0_rk8)
#define CONST_COMPLEX_PAIR_1_0 (1.0_rk8,0.0_rk8)
#define CONST_COMPLEX_PAIR_NEGATIVE_1_0 (-1.0_rk8,0.0_rk8)
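The MPI calls follow the same scheme: MPI_COMPLEX16 and MPI_COMPLEX8 in the old branches become MPI_COMPLEX_EXPLICIT_PRECISION in the unified code. The corresponding defines are not visible in the hunks shown above; judging from the replaced calls, they presumably read:

! double-precision branch:
#define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX16
! single-precision branch:
#define MPI_COMPLEX_EXPLICIT_PRECISION MPI_COMPLEX8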