Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
elpa
elpa
Commits
58c393cd
Commit
58c393cd
authored
Oct 02, 2017
by
Pavel Kus
Browse files
further real/complex unifications
in elpa2_bandred_template.F90
parent
e62e2dd8
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/elpa2/elpa2_bandred_template.F90
View file @
58c393cd
...
...
@@ -591,13 +591,7 @@
#ifdef WITH_MPI
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
mpi_allreduce
(
aux1
,
aux2
,
2
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,
&
#endif
call
mpi_allreduce
(
aux1
,
aux2
,
2
,
MPI_MATH_DATATYPE_PRECISION
,
&
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
if
(
wantDebug
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
...
...
@@ -609,14 +603,11 @@
vrl
=
aux2
(
2
)
! Householder transformation
#if REALCASE == 1
call
hh_transform_real_
&
#endif
#if COMPLEXCASE == 1
call
hh_transform_complex_
&
#endif
call
hh_transform_
&
&
MATH_DATATYPE
&
&
_
&
&
PRECISION
&
(
obj
,
vrl
,
vnorm2
,
xf
,
tau
,
wantDebug
)
(
obj
,
vrl
,
vnorm2
,
xf
,
tau
,
wantDebug
)
! Scale vr and store Householder Vector for back transformation
vr
(
1
:
lr
)
=
vr
(
1
:
lr
)
*
xf
...
...
@@ -635,13 +626,7 @@
vr
(
lr
+1
)
=
tau
#ifdef WITH_MPI
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
MPI_Bcast
(
vr
,
lr
+1
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,
&
#endif
call
MPI_Bcast
(
vr
,
lr
+1
,
MPI_MATH_DATATYPE_PRECISION
,
&
cur_pcol
,
mpi_comm_cols
,
mpierr
)
if
(
wantDebug
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
...
...
@@ -754,14 +739,8 @@
!$omp single
#ifdef WITH_MPI
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
if
(
mynlc
>
0
)
call
mpi_allreduce
(
aux1
,
aux2
,
mynlc
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,
&
#endif
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
if
(
mynlc
>
0
)
call
mpi_allreduce
(
aux1
,
aux2
,
mynlc
,
MPI_MATH_DATATYPE_PRECISION
,
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
if
(
wantDebug
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
#else /* WITH_MPI */
if
(
mynlc
>
0
)
aux2
=
aux1
...
...
@@ -807,14 +786,8 @@
! Get global dot products
#ifdef WITH_MPI
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
if
(
nlc
>
0
)
call
mpi_allreduce
(
aux1
,
aux2
,
nlc
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,&
#endif
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
if
(
nlc
>
0
)
call
mpi_allreduce
(
aux1
,
aux2
,
nlc
,
MPI_MATH_DATATYPE_PRECISION
,
&
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
if
(
wantDebug
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
#else /* WITH_MPI */
if
(
nlc
>
0
)
aux2
=
aux1
...
...
@@ -899,12 +872,7 @@
do
lc
=
n_cols
,
1
,
-1
tau
=
tmat
(
lc
,
lc
,
istep
)
if
(
lc
<
n_cols
)
then
#if REALCASE == 1
call
PRECISION_TRMV
(
'U'
,
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_TRMV
(
'U'
,
'C'
,
'N'
,
&
#endif
call
PRECISION_TRMV
(
'U'
,
BLAS_TRANS_OR_CONJ
,
'N'
,&
n_cols
-
lc
,
tmat
(
lc
+1
,
lc
+1
,
istep
),
ubound
(
tmat
,
dim
=
1
),
vav
(
lc
+1
,
lc
),
1
)
#if REALCASE == 1
...
...
@@ -1040,12 +1008,7 @@
! C1 += A10' B0
if
(
lce
>
lcs
.and.
i
>
0
)
then
call
obj
%
timer
%
start
(
"blas"
)
#if REALCASE == 1
call
PRECISION_GEMM
(
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_GEMM
(
'C'
,
'N'
,
&
#endif
call
PRECISION_GEMM
(
BLAS_TRANS_OR_CONJ
,
'N'
,
&
lce
-
lcs
+1
,
n_cols
,
lrs
-1
,
&
ONE
,
a
(
1
,
lcs
),
ubound
(
a
,
dim
=
1
),
&
vmrCPU
(
1
,
1
),
ubound
(
vmrCPU
,
dim
=
1
),
&
...
...
@@ -1099,12 +1062,7 @@
if
(
useGPU
)
then
call
obj
%
timer
%
start
(
"cublas"
)
#if REALCASE == 1
call
cublas_PRECISION_GEMM
(
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
cublas_PRECISION_GEMM
(
'C'
,
'N'
,
&
#endif
call
cublas_PRECISION_GEMM
(
BLAS_TRANS_OR_CONJ
,
'N'
,
&
lce
-
lcs
+1
,
n_cols
,
lre
,
&
ONE
,
(
a_dev
+
((
lcs
-1
)
*
lda
*
&
size_of_datatype
)),
&
...
...
@@ -1131,12 +1089,7 @@
else
! useGPU
call
obj
%
timer
%
start
(
"blas"
)
#if REALCASE == 1
call
PRECISION_GEMM
(
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_GEMM
(
'C'
,
'N'
,
&
#endif
call
PRECISION_GEMM
(
BLAS_TRANS_OR_CONJ
,
'N'
,
&
lce
-
lcs
+1
,
n_cols
,
lre
,
ONE
,
a
(
1
,
lcs
),
ubound
(
a
,
dim
=
1
),
&
vmrCPU
,
ubound
(
vmrCPU
,
dim
=
1
),
ONE
,
umcCPU
(
lcs
,
1
),
ubound
(
umcCPU
,
dim
=
1
))
call
obj
%
timer
%
stop
(
"blas"
)
...
...
@@ -1222,13 +1175,7 @@
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
mpi_allreduce
(
umcCUDA
,
tmpCUDA
,
l_cols
*
n_cols
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,
&
#endif
call
mpi_allreduce
(
umcCUDA
,
tmpCUDA
,
l_cols
*
n_cols
,
MPI_MATH_DATATYPE_PRECISION
,
&
MPI_SUM
,
mpi_comm_rows
,
ierr
)
umcCUDA
(
1
:
l_cols
*
n_cols
)
=
tmpCUDA
(
1
:
l_cols
*
n_cols
)
...
...
@@ -1261,13 +1208,7 @@
#ifdef WITH_MPI
if
(
wantDebug
)
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
mpi_allreduce
(
umcCPU
,
tmpCPU
,
l_cols
*
n_cols
,
&
#if REALCASE == 1
MPI_REAL_PRECISION
,
&
#endif
#if COMPLEXCASE == 1
MPI_COMPLEX_PRECISION
,
&
#endif
call
mpi_allreduce
(
umcCPU
,
tmpCPU
,
l_cols
*
n_cols
,
MPI_MATH_DATATYPE_PRECISION
,
&
MPI_SUM
,
mpi_comm_rows
,
mpierr
)
umcCPU
(
1
:
l_cols
,
1
:
n_cols
)
=
tmpCPU
(
1
:
l_cols
,
1
:
n_cols
)
if
(
wantDebug
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
...
...
@@ -1306,12 +1247,7 @@
endif
call
obj
%
timer
%
start
(
"cublas"
)
#if REALCASE == 1
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
'Trans'
,
'Nonunit'
,
&
#endif
#if COMPLEXCASE == 1
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
'C'
,
'Nonunit'
,
&
#endif
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
BLAS_TRANS_OR_CONJ
,
'Nonunit'
,
&
l_cols
,
n_cols
,
ONE
,
tmat_dev
,
nbw
,
umc_dev
,
cur_l_cols
)
call
obj
%
timer
%
stop
(
"cublas"
)
...
...
@@ -1325,22 +1261,12 @@
endif
call
obj
%
timer
%
start
(
"cublas"
)
#if REALCASE == 1
call
cublas_PRECISION_GEMM
(
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
cublas_PRECISION_GEMM
(
'C'
,
'N'
,
&
#endif
call
cublas_PRECISION_GEMM
(
BLAS_TRANS_OR_CONJ
,
'N'
,
&
n_cols
,
n_cols
,
l_cols
,
ONE
,
umc_dev
,
cur_l_cols
,
&
(
umc_dev
+
(
cur_l_cols
*
n_cols
)
*
size_of_datatype
),
cur_l_cols
,
&
ZERO
,
vav_dev
,
nbw
)
#if REALCASE == 1
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
'Trans'
,
'Nonunit'
,
&
#endif
#if COMPLEXCASE == 1
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
'C'
,
'Nonunit'
,
&
#endif
call
cublas_PRECISION_TRMM
(
'Right'
,
'Upper'
,
BLAS_TRANS_OR_CONJ
,
'Nonunit'
,
&
n_cols
,
n_cols
,
ONE
,
tmat_dev
,
nbw
,
vav_dev
,
nbw
)
call
obj
%
timer
%
stop
(
"cublas"
)
...
...
@@ -1355,32 +1281,17 @@
call
obj
%
timer
%
start
(
"blas"
)
#if REALCASE == 1
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
'Trans'
,
'Nonunit'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
'C'
,
'Nonunit'
,
&
#endif
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
BLAS_TRANS_OR_CONJ
,
'Nonunit'
,
&
l_cols
,
n_cols
,
ONE
,
tmat
(
1
,
1
,
istep
),
ubound
(
tmat
,
dim
=
1
),
&
umcCPU
,
ubound
(
umcCPU
,
dim
=
1
))
! VAV = Tmat * V**T * A * V * Tmat**T = (U*Tmat**T)**T * V * Tmat**T
#if REALCASE == 1
call
PRECISION_GEMM
(
'T'
,
'N'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_GEMM
(
'C'
,
'N'
,
&
#endif
call
PRECISION_GEMM
(
BLAS_TRANS_OR_CONJ
,
'N'
,
&
n_cols
,
n_cols
,
l_cols
,
ONE
,
umcCPU
,
ubound
(
umcCPU
,
dim
=
1
),
umcCPU
(
1
,
n_cols
+1
),
&
ubound
(
umcCPU
,
dim
=
1
),
ZERO
,
vav
,
ubound
(
vav
,
dim
=
1
))
#if REALCASE == 1
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
'Trans'
,
'Nonunit'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
'C'
,
'Nonunit'
,
&
#endif
call
PRECISION_TRMM
(
'Right'
,
'Upper'
,
BLAS_TRANS_OR_CONJ
,
'Nonunit'
,
&
n_cols
,
n_cols
,
ONE
,
tmat
(
1
,
1
,
istep
),
&
ubound
(
tmat
,
dim
=
1
),
vav
,
ubound
(
vav
,
dim
=
1
))
call
obj
%
timer
%
stop
(
"blas"
)
...
...
@@ -1519,16 +1430,9 @@
if
(
myend
>
lre
)
myend
=
lre
if
(
myend
-
mystart
+1
<
1
)
cycle
call
obj
%
timer
%
start
(
"blas"
)
#if REALCASE == 1
call
PRECISION_GEMM
(
'N'
,
'T'
,
myend
-
mystart
+1
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
call
PRECISION_GEMM
(
'N'
,
BLAS_TRANS_OR_CONJ
,
myend
-
mystart
+1
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
vmrCPU
(
mystart
,
1
),
ubound
(
vmrCPU
,
1
),
umcCPU
(
lcs
,
1
),
ubound
(
umcCPU
,
1
),
&
ONE
,
a
(
mystart
,
lcs
),
ubound
(
a
,
1
))
#endif
#if COMPLEXCASE == 1
call
PRECISION_GEMM
(
'N'
,
'C'
,
myend
-
mystart
+1
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
vmrCPU
(
mystart
,
1
),
ubound
(
vmrCPU
,
1
),
umcCPU
(
lcs
,
1
),
ubound
(
umcCPU
,
1
),
&
one
,
a
(
mystart
,
lcs
),
ubound
(
a
,
1
))
#endif
call
obj
%
timer
%
stop
(
"blas"
)
enddo
!$omp end parallel
...
...
@@ -1557,12 +1461,7 @@
if
(
useGPU
)
then
call
obj
%
timer
%
start
(
"cublas"
)
#if REALCASE == 1
call
cublas_PRECISION_GEMM
(
'N'
,
'T'
,
&
#endif
#if COMPLEXCASE == 1
call
cublas_PRECISION_GEMM
(
'N'
,
'C'
,
&
#endif
call
cublas_PRECISION_GEMM
(
'N'
,
BLAS_TRANS_OR_CONJ
,
&
lre
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
vmr_dev
,
cur_l_rows
,
(
umc_dev
+
(
lcs
-1
)
*
&
size_of_datatype
),
&
...
...
@@ -1573,13 +1472,7 @@
else
! useGPU
call
obj
%
timer
%
start
(
"blas"
)
#if REALCASE == 1
call
PRECISION_GEMM
(
'N'
,
'T'
,
&
#endif
#if COMPLEXCASE == 1
call
PRECISION_GEMM
(
'N'
,
'C'
,
&
#endif
lre
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
call
PRECISION_GEMM
(
'N'
,
BLAS_TRANS_OR_CONJ
,
lre
,
lce
-
lcs
+1
,
2
*
n_cols
,
-
ONE
,
&
vmrCPU
,
ubound
(
vmrCPU
,
dim
=
1
),
umcCPU
(
lcs
,
1
),
ubound
(
umcCPU
,
dim
=
1
),
&
ONE
,
a
(
1
,
lcs
),
lda
)
call
obj
%
timer
%
stop
(
"blas"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment