Commit 7c19d880 authored by Pavel Kus's avatar Pavel Kus
Browse files

single/double unification of elpa1_compute_complex_template.X90

Conflicts:
	src/elpa1_compute_complex_template.X90
parent c2037d4e
......@@ -56,7 +56,8 @@ EXTRA_libelpa@SUFFIX@_private_la_DEPENDENCIES = \
src/elpa2_kernels/elpa2_kernels_complex_template.X90 \
src/elpa2_kernels/elpa2_kernels_simple_template.X90 \
src/redist_band.X90 \
src/precision_macros.h
src/precision_macros.h \
src/precision_macros_complex.h
lib_LTLIBRARIES = libelpa@SUFFIX@.la
libelpa@SUFFIX@_la_LINK = $(FCLINK) $(AM_LDFLAGS) -version-info $(ELPA_SO_VERSION)
......@@ -888,6 +889,7 @@ EXTRA_DIST = \
src/elpa2_compute_real_template.X90 \
src/elpa2_compute_complex_template.X90 \
src/precision_macros.h \
src/precision_macros_complex.h \
src/elpa2_kernels/elpa2_kernels_real_template.X90 \
src/elpa2_kernels/elpa2_kernels_complex_template.X90 \
src/elpa2_kernels/elpa2_kernels_simple_template.X90 \
......
#!/usr/bin/python
simple_tokens = ["tridiag_complex_PRECISION",
"trans_ev_complex_PRECISION",
"solve_complex_PRECISION",
"hh_transform_complex_PRECISION",
"elpa_transpose_vectors_complex_PRECISION",
"elpa_reduce_add_vectors_complex_PRECISION",
]
blas_tokens = ["PRECISION_GEMV",
"PRECISION_TRMV",
"PRECISION_GEMM",
"PRECISION_TRMM",
"PRECISION_HERK",
]
explicit_tokens = [("PRECISION_SUFFIX", "\"_double\"", "\"_single\""),
("MPI_COMPLEX_PRECISION", "MPI_DOUBLE_COMPLEX", "MPI_COMPLEX"),
("MPI_REAL_PRECISION", "MPI_REAL8", "MPI_REAL4"),
("KIND_PRECISION", "rk8", "rk4"),
("PRECISION_CMPLX", "DCMPLX", "CMPLX"),
("PRECISION_IMAG", "DIMAG", "AIMAG"),
("CONST_REAL_0_0", "0.0_rk8", "0.0_rk4"),
("CONST_REAL_1_0", "1.0_rk8", "1.0_rk4"),
]
print "#ifdef DOUBLE_PRECISION_COMPLEX"
for token in simple_tokens:
print "#define ", token, token.replace("PRECISION", "double")
for token in blas_tokens:
print "#define ", token, token.replace("PRECISION_", "Z")
for token in explicit_tokens:
print "#define ", token[0], token[1]
print "#else"
for token in simple_tokens:
print "#undef ", token
for token in blas_tokens:
print "#undef ", token
for token in explicit_tokens:
print "#undef ", token[0]
for token in simple_tokens:
print "#define ", token, token.replace("PRECISION", "single")
for token in blas_tokens:
print "#define ", token, token.replace("PRECISION_", "C")
for token in explicit_tokens:
print "#define ", token[0], token[2]
print "#endif"
......@@ -52,11 +52,10 @@
! distributed along with the original code in the file "COPYING".
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
subroutine tridiag_complex_double(na, a, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, d, e, tau)
#else
subroutine tridiag_complex_single(na, a, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, d, e, tau)
#endif
#include "precision_macros_complex.h"
subroutine tridiag_complex_PRECISION(na, a, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, d, e, tau)
!-------------------------------------------------------------------------------
! tridiag_complex: Reduces a distributed hermitian matrix to tridiagonal form
! (like Scalapack Routine PZHETRD)
......@@ -88,6 +87,8 @@
!-------------------------------------------------------------------------------
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
use precision
implicit none
......@@ -129,13 +130,7 @@
integer(kind=ik) :: istat
character(200) :: errorMessage
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%start("tridiag_complex_double")
#else
call timer%start("tridiag_complex_single")
#endif
#endif
call timer%start("tridiag_complex" // PRECISION_SUFFIX)
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
......@@ -257,13 +252,8 @@
vr(1:l_rows) = a(1:l_rows,l_cols+1)
if (nstor>0 .and. l_rows>0) then
aux(1:2*nstor) = conjg(uvc(l_cols+1,1:2*nstor))
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('N', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), &
aux, 1, CONE, vr, 1)
#else
call CGEMV('N', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), &
call PRECISION_GEMV('N', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), &
aux, 1, CONE, vr, 1)
#endif
endif
if (my_prow==prow(istep-1, nblk, np_rows)) then
......@@ -277,13 +267,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_allreduce(aux1, aux2, 2, MPI_DOUBLE_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#else
call mpi_allreduce(aux1, aux2, 2, MPI_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#endif
call mpi_allreduce(aux1, aux2, 2, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
vnorm2 = aux2(1)
vrl = aux2(2)
#ifdef HAVE_DETAILED_TIMINGS
......@@ -302,11 +286,7 @@
! vrl = aux2(2)
! Householder transformation
#ifdef DOUBLE_PRECISION_COMPLEX
call hh_transform_complex_double(vrl, vnorm2, xf, tau(istep))
#else
call hh_transform_complex_single(vrl, vnorm2, xf, tau(istep))
#endif
call hh_transform_complex_PRECISION(vrl, vnorm2, xf, tau(istep))
! Scale vr and store Householder vector for back transformation
vr(1:l_rows) = vr(1:l_rows) * xf
......@@ -325,12 +305,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call MPI_Bcast(vr, l_rows+1, MPI_DOUBLE_COMPLEX, pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr)
#else
call MPI_Bcast(vr, l_rows+1, MPI_COMPLEX, pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr)
#endif
call MPI_Bcast(vr, l_rows+1, MPI_COMPLEX_PRECISION, pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -343,15 +318,9 @@
! call elpa_transpose_vectors (vr, 2*ubound(vr,dim=1), mpi_comm_rows, &
! vc, 2*ubound(vc,dim=1), mpi_comm_cols, &
! 1, 2*(istep-1), 1, 2*nblk)
#ifdef DOUBLE_PRECISION_COMPLEX
call elpa_transpose_vectors_complex_double (vr, ubound(vr,dim=1), mpi_comm_rows, &
call elpa_transpose_vectors_complex_PRECISION (vr, ubound(vr,dim=1), mpi_comm_rows, &
vc, ubound(vc,dim=1), mpi_comm_cols, &
1, (istep-1), 1, nblk)
#else
call elpa_transpose_vectors_complex_single (vr, ubound(vr,dim=1), mpi_comm_rows, &
vc, ubound(vc,dim=1), mpi_comm_cols, &
1, (istep-1), 1, nblk)
#endif
! Calculate u = (A + VU**T + UV**T)*v
! For cache efficiency, we use only the upper half of the matrix tiles for this,
......@@ -362,14 +331,7 @@
if (l_rows>0 .and. l_cols>0) then
#ifdef WITH_OPENMP
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%start("OpenMP parallel_double")
#else
call timer%start("OpenMP parallel_single")
#endif
#endif
call timer%start("OpenMP parallel" // PRECISION_SUFFIX)
!$OMP PARALLEL PRIVATE(my_thread,n_threads,n_iter,i,lcs,lce,j,lrs,lre)
......@@ -392,46 +354,24 @@
if (lre<lrs) cycle
#ifdef WITH_OPENMP
if (mod(n_iter,n_threads) == my_thread) then
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('C', lre-lrs+1 ,lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc_p(lcs,my_thread), 1)
call PRECISION_GEMV('C', lre-lrs+1 ,lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc_p(lcs,my_thread), 1)
if (i/=j) then
call ZGEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur_p(lrs,my_thread), 1)
call PRECISION_GEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur_p(lrs,my_thread), 1)
endif
#else
call CGEMV('C', lre-lrs+1 ,lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc_p(lcs,my_thread), 1)
if (i/=j) then
call CGEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur_p(lrs,my_thread), 1)
endif
#endif
endif
n_iter = n_iter+1
#else /* WITH_OPENMP */
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('C', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc(lcs), 1)
call PRECISION_GEMV('C', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc(lcs), 1)
if (i/=j) then
call ZGEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur(lrs), 1)
call PRECISION_GEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur(lrs), 1)
endif
#else
call CGEMV('C', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc(lcs), 1)
if (i/=j) then
call CGEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur(lrs), 1)
endif
#endif
#endif /* WITH_OPENMP */
enddo
enddo
#ifdef WITH_OPENMP
!$OMP END PARALLEL
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%stop("OpenMP parallel_double")
#else
call timer%stop("OpenMP parallel_single")
#endif
#endif
call timer%stop("OpenMP parallel" // PRECISION_SUFFIX)
do i=0,max_threads-1
uc(1:l_cols) = uc(1:l_cols) + uc_p(1:l_cols,i)
......@@ -440,13 +380,8 @@
#endif
if (nstor>0) then
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMV('C', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), vr, 1, CZERO, aux, 1)
call ZGEMV('N', l_cols, 2*nstor, CONE, uvc, ubound(uvc,dim=1), aux, 1, CONE, uc, 1)
#else
call CGEMV('C', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), vr, 1, CZERO, aux, 1)
call CGEMV('N', l_cols, 2*nstor, CONE, uvc, ubound(uvc,dim=1), aux, 1, CONE, uc, 1)
#endif
call PRECISION_GEMV('C', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), vr, 1, CZERO, aux, 1)
call PRECISION_GEMV('N', l_cols, 2*nstor, CONE, uvc, ubound(uvc,dim=1), aux, 1, CONE, uc, 1)
endif
endif
......@@ -457,15 +392,9 @@
! global tile size is smaller than the global remaining matrix
if (tile_size < istep-1) then
#ifdef DOUBLE_PRECISION_COMPLEX
call elpa_reduce_add_vectors_COMPLEX_double (ur, ubound(ur,dim=1), mpi_comm_rows, &
call elpa_reduce_add_vectors_complex_PRECISION (ur, ubound(ur,dim=1), mpi_comm_rows, &
uc, ubound(uc,dim=1), mpi_comm_cols, &
(istep-1), 1, nblk)
#else
call elpa_reduce_add_vectors_COMPLEX_single (ur, ubound(ur,dim=1), mpi_comm_rows, &
uc, ubound(uc,dim=1), mpi_comm_cols, &
(istep-1), 1, nblk)
#endif
endif
! Sum up all the uc(:) parts, transpose uc -> ur
......@@ -476,12 +405,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_allreduce(tmp, uc, l_cols, MPI_DOUBLE_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#else
call mpi_allreduce(tmp, uc, l_cols, MPI_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#endif
call mpi_allreduce(tmp, uc, l_cols, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -494,16 +418,9 @@
! call elpa_transpose_vectors (uc, 2*ubound(uc,dim=1), mpi_comm_cols, &
! ur, 2*ubound(ur,dim=1), mpi_comm_rows, &
! 1, 2*(istep-1), 1, 2*nblk)
#ifdef DOUBLE_PRECISION_COMPLEX
call elpa_transpose_vectors_complex_double (uc, ubound(uc,dim=1), mpi_comm_cols, &
call elpa_transpose_vectors_complex_PRECISION (uc, ubound(uc,dim=1), mpi_comm_cols, &
ur, ubound(ur,dim=1), mpi_comm_rows, &
1, (istep-1), 1, nblk)
#else
call elpa_transpose_vectors_complex_single (uc, ubound(uc,dim=1), mpi_comm_cols, &
ur, ubound(ur,dim=1), mpi_comm_rows, &
1, (istep-1), 1, nblk)
#endif
! calculate u**T * v (same as v**T * (A + VU**T + UV**T) * v )
......@@ -513,12 +430,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_allreduce(xc, vav, 1 , MPI_DOUBLE_COMPLEX, MPI_SUM, mpi_comm_cols, mpierr)
#else
call mpi_allreduce(xc, vav, 1 , MPI_COMPLEX, MPI_SUM, mpi_comm_cols, mpierr)
#endif
call mpi_allreduce(xc, vav, 1 , MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -551,15 +463,9 @@
lrs = 1
lre = min(l_rows,(i+1)*l_rows_tile)
if (lce<lcs .or. lre<lrs) cycle
#ifdef DOUBLE_PRECISION_COMPLEX
call ZGEMM('N', 'C', lre-lrs+1, lce-lcs+1, 2*nstor, CONE, &
call PRECISION_GEMM('N', 'C', lre-lrs+1, lce-lcs+1, 2*nstor, CONE, &
vur(lrs,1), ubound(vur,dim=1), uvc(lcs,1), ubound(uvc,dim=1), &
CONE, a(lrs,lcs), lda)
#else
call CGEMM('N', 'C', lre-lrs+1, lce-lcs+1, 2*nstor, CONE, &
vur(lrs,1), ubound(vur,dim=1), uvc(lcs,1), ubound(uvc,dim=1), &
CONE, a(lrs,lcs), lda)
#endif
enddo
nstor = 0
......@@ -580,11 +486,7 @@
if (my_prow==prow(1, nblk, np_rows)) then
! We use last l_cols value of loop above
vrl = a(1,l_cols)
#ifdef DOUBLE_PRECISION_COMPLEX
call hh_transform_complex_double(vrl, 0.0_rk8, xf, tau(2))
#else
call hh_transform_complex_single(vrl, 0.0_rk4, xf, tau(2))
#endif
call hh_transform_complex_PRECISION(vrl, CONST_REAL_0_0, xf, tau(2))
e(1) = vrl
a(1,l_cols) = 1. ! for consistency only
endif
......@@ -593,12 +495,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_bcast(tau(2), 1, MPI_DOUBLE_COMPLEX, prow(1, nblk, np_rows), mpi_comm_rows, mpierr)
#else
call mpi_bcast(tau(2), 1, MPI_COMPLEX, prow(1, nblk, np_rows), mpi_comm_rows, mpierr)
#endif
call mpi_bcast(tau(2), 1, MPI_COMPLEX_PRECISION, prow(1, nblk, np_rows), mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -610,12 +507,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_bcast(tau(2), 1, MPI_DOUBLE_COMPLEX, pcol(2, nblk, np_cols), mpi_comm_cols, mpierr)
#else
call mpi_bcast(tau(2), 1, MPI_COMPLEX, pcol(2, nblk, np_cols), mpi_comm_cols, mpierr)
#endif
call mpi_bcast(tau(2), 1, MPI_COMPLEX_PRECISION, pcol(2, nblk, np_cols), mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -642,26 +534,14 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
tmpr = d
call mpi_allreduce(tmpr, d, na, MPI_REAL8, MPI_SUM, mpi_comm_rows, mpierr)
tmpr = d
call mpi_allreduce(tmpr, d, na, MPI_REAL8 ,MPI_SUM, mpi_comm_cols, mpierr)
tmpr = e
call mpi_allreduce(tmpr, e, na, MPI_REAL8, MPI_SUM, mpi_comm_rows, mpierr)
tmpr = e
call mpi_allreduce(tmpr, e, na, MPI_REAL8, MPI_SUM, mpi_comm_cols, mpierr)
#else
tmpr = d
call mpi_allreduce(tmpr, d, na, MPI_REAL4, MPI_SUM, mpi_comm_rows, mpierr)
call mpi_allreduce(tmpr, d, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
tmpr = d
call mpi_allreduce(tmpr, d, na, MPI_REAL4 ,MPI_SUM, mpi_comm_cols, mpierr)
call mpi_allreduce(tmpr, d, na, MPI_REAL_PRECISION ,MPI_SUM, mpi_comm_cols, mpierr)
tmpr = e
call mpi_allreduce(tmpr, e, na, MPI_REAL4, MPI_SUM, mpi_comm_rows, mpierr)
call mpi_allreduce(tmpr, e, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
tmpr = e
call mpi_allreduce(tmpr, e, na, MPI_REAL4, MPI_SUM, mpi_comm_cols, mpierr)
#endif
call mpi_allreduce(tmpr, e, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -673,25 +553,11 @@
stop
endif
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%stop("tridiag_complex_double")
#else
call timer%stop("tridiag_complex_single")
#endif
#endif
call timer%stop("tridiag_complex" // PRECISION_SUFFIX)
#ifdef DOUBLE_PRECISION_COMPLEX
end subroutine tridiag_complex_double
#else
end subroutine tridiag_complex_single
#endif
end subroutine tridiag_complex_PRECISION
#ifdef DOUBLE_PRECISION_COMPLEX
subroutine trans_ev_complex_double(na, nqc, a, lda, tau, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
#else
subroutine trans_ev_complex_single(na, nqc, a, lda, tau, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
#endif
subroutine trans_ev_complex_PRECISION(na, nqc, a, lda, tau, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
!-------------------------------------------------------------------------------
! trans_ev_complex: Transforms the eigenvectors of a tridiagonal matrix back
! to the eigenvectors of the original matrix
......@@ -725,6 +591,8 @@
!-------------------------------------------------------------------------------
#ifdef HAVE_DETAILED_TIMINGS
use timings
#else
use timings_dummy
#endif
use precision
implicit none
......@@ -751,13 +619,8 @@
complex(kind=COMPLEX_DATATYPE), allocatable :: tmat(:,:), h1(:), h2(:)
integer(kind=ik) :: istat
character(200) :: errorMessage
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%start("trans_ev_complex_double")
#else
call timer%start("trans_ev_complex_single")
#endif
#endif
call timer%start("trans_ev_complex" // PRECISION_SUFFIX)
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
......@@ -830,11 +693,7 @@
! In the complex case tau(2) /= 0
if (my_prow == prow(1, nblk, np_rows)) then
#ifdef DOUBLE_PRECISION_COMPLEX
q(1,1:l_cols) = q(1,1:l_cols)*((1.0_rk8,0.0_rk8)-tau(2))
#else
q(1,1:l_cols) = q(1,1:l_cols)*((1.0_rk4,0.0_rk4)-tau(2))
#endif
q(1,1:l_cols) = q(1,1:l_cols)*(CONE-tau(2))
endif
do istep=1,na,nblk
......@@ -868,11 +727,7 @@
#endif
if (nb>0) &
#ifdef DOUBLE_PRECISION_COMPLEX
call MPI_Bcast(hvb, nb, MPI_DOUBLE_COMPLEX, cur_pcol, mpi_comm_cols, mpierr)
#else
call MPI_Bcast(hvb, nb, MPI_COMPLEX, cur_pcol, mpi_comm_cols, mpierr)
#endif
call MPI_Bcast(hvb, nb, MPI_COMPLEX_PRECISION, cur_pcol, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -894,11 +749,7 @@
tmat = 0
if (l_rows>0) &
#ifdef DOUBLE_PRECISION_COMPLEX
call zherk('U', 'C', nstor, l_rows, CONE, hvm, ubound(hvm,dim=1), CZERO, tmat, max_stored_rows)
#else
call cherk('U', 'C', nstor, l_rows, CONE, hvm, ubound(hvm,dim=1), CZERO, tmat, max_stored_rows)
#endif
call PRECISION_HERK('U', 'C', nstor, l_rows, CONE, hvm, ubound(hvm,dim=1), CZERO, tmat, max_stored_rows)
nc = 0
do n=1,nstor-1
h1(nc+1:nc+n) = tmat(1:n,n+1)
......@@ -908,12 +759,7 @@
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
if (nc>0) call mpi_allreduce(h1, h2, nc, MPI_DOUBLE_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#else
if (nc>0) call mpi_allreduce(h1, h2, nc, MPI_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#endif
if (nc>0) call mpi_allreduce(h1, h2, nc, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
......@@ -926,11 +772,7 @@
nc = 0
tmat(1,1) = tau(ice-nstor+1)
do n=1,nstor-1
#ifdef DOUBLE_PRECISION_COMPLEX
call ztrmv('L', 'C', 'N', n, tmat, max_stored_rows, h2(nc+1),1)
#else
call ctrmv('L', 'C', 'N', n, tmat, max_stored_rows, h2(nc+1),1)
#endif
call PRECISION_TRMV('L', 'C', 'N', n, tmat, max_stored_rows, h2(nc+1),1)
tmat(n+1,1:n) = -conjg(h2(nc+1:nc+n))*tau(ice-nstor+n+1)
tmat(n+1,n+1) = tau(ice-nstor+n+1)
nc = nc+n
......@@ -939,31 +781,24 @@
! Q = Q - V * T * V**T * Q
if (l_rows>0) then
#ifdef DOUBLE_PRECISION_COMPLEX
call zgemm('C', 'N', nstor, l_cols, l_rows, CONE, hvm, ubound(hvm,dim=1), &
call PRECISION_GEMM('C', 'N', nstor, l_cols, l_rows, CONE, hvm, ubound(hvm,dim=1), &
q, ldq, CZERO, tmp1 ,nstor)
#else
call cgemm('C', 'N', nstor, l_cols, l_rows, CONE, hvm, ubound(hvm,dim=1), &
q, ldq, CZERO, tmp1 ,nstor)
#endif
else
tmp1(1:l_cols*nstor) = 0
endif
#ifdef DOUBLE_PRECISION_COMPLEX
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
call mpi_allreduce(tmp1, tmp2, nstor*l_cols, MPI_DOUBLE_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
call mpi_allreduce(tmp1, tmp2, nstor*l_cols, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
if (l_rows>0) then
call ztrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
call zgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
tmp2, nstor, CONE, q, ldq)
endif
......@@ -971,55 +806,19 @@
! tmp2 = tmp1
if (l_rows>0) then
call ztrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp1, nstor)
call zgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp1, nstor)
call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
tmp1, nstor, CONE, q, ldq)
endif
#endif
! if (l_rows>0) then
! call ztrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
! call zgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
! tmp2, nstor, CONE, q, ldq)
! endif
#else /* DOUBLE_PRECISION_COMPLEX */
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
call timer%start("mpi_communication")
#endif
call mpi_allreduce(tmp1, tmp2, nstor*l_cols, MPI_COMPLEX, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
call timer%stop("mpi_communication")
#endif
if (l_rows>0) then
call ctrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
call cgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
tmp2, nstor, CONE, q, ldq)
endif
#else
! tmp2 = tmp1
if (l_rows>0) then
call ctrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp1, nstor)
call cgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
tmp1, nstor, CONE, q, ldq)
endif
#endif
!
! if (l_rows>0) then
! call ctrmm('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
! call cgemm('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
! call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
! call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
! tmp2, nstor, CONE, q, ldq)
! endif
#endif /* DOUBLE_PRECISION_COMPLEX */
nstor = 0
endif
......@@ -1031,25 +830,11 @@
stop
endif
#ifdef HAVE_DETAILED_TIMINGS
#ifdef DOUBLE_PRECISION_COMPLEX
call timer%stop("trans_ev_complex_double")
#else
call timer%stop("trans_ev_complex_single")
#endif
#endif
call timer%stop("trans_ev_complex" // PRECISION_SUFFIX)
#ifdef DOUBLE_PRECISION_COMPLEX
end subroutine trans_ev_complex_double
#else
end subroutine trans_ev_complex_single
#endif
end subroutine trans_ev_complex_PRECISION
#ifdef DOUBLE_PRECISION_COMPLEX
subroutine hh_transform_complex_double(alpha, xnorm_sq, xf, tau)
#else
subroutine hh_transform_complex_single(alpha, xnorm_sq, xf, tau)
#endif
subroutine hh_transform_complex_PRECISION(alpha, xnorm_sq, xf, tau)
! Similar to LAPACK routine ZLARFP, but uses ||x||**2 instead of x(:)