Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
elpa
elpa
Commits
2d1ad5c9
Commit
2d1ad5c9
authored
Aug 28, 2017
by
Andreas Marek
Browse files
Unify real/complex GPU code path in band_to_full
parent
cc0bf8f5
Changes
1
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
src/elpa2/elpa2_trans_ev_band_to_full_template.F90
View file @
2d1ad5c9
...
...
@@ -308,7 +308,7 @@
#ifdef WITH_MPI
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
MPI_Bcast
(
hvb
(
ns
+1
),
nb
-
ns
,
MPI_MATH_DATATYPE_PRECISION
,&
pcol
(
ncol
,
nblk
,
np_cols
),
mpi_comm_cols
,
mpierr
)
pcol
(
ncol
,
nblk
,
np_cols
),
mpi_comm_cols
,
mpierr
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
...
...
@@ -348,7 +348,6 @@
q_dev
,
ldq
,
ZERO
,
tmp_dev
,
n_cols
)
call
obj
%
timer
%
stop
(
"cublas"
)
#if REALCASE == 1
#ifdef WITH_MPI
! copy data from device to host for a later MPI_ALLREDUCE
...
...
@@ -359,34 +358,15 @@
stop
1
endif
#else /* WITH_MPI */
! in real case no copy needed.
maybe also
in complexcase
?
! in real case no copy needed.
Don't do it
in complex
case
neither
#endif /* WITH_MPI */
#endif /* REALCASE */
#if COMPLEXCASE == 1
successCUDA
=
cuda_memcpy
(
loc
(
tmp1
),
tmp_dev
,
n_cols
*
l_cols
*
size_of_datatype
,
&
cudaMemcpyDeviceToHost
)
if
(
.not.
(
successCUDA
))
then
print
*
,
"trans_ev_band_to_full_complex: error in cudaMemcpy"
stop
1
endif
#endif
else
! l_rows>0
tmp1
(
1
:
l_cols
*
n_cols
)
=
0.0_rck
endif
! l_rows>0
!#ifdef WITH_GPU_VERSION
! istat = cuda_memcpy(loc(tmp1), tmp_dev, max_local_cols*nbw*size_of_datatype,cudaMemcpyDeviceToHost)
! if (istat .ne. 0) then
! print *,"error in cudaMemcpy"
! stop 1
! endif
!#endif
#ifdef WITH_MPI
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
mpi_allreduce
(
tmp1
,
tmp2
,
n_cols
*
l_cols
,
MPI_MATH_DATATYPE_PRECISION
,
&
...
...
@@ -397,14 +377,6 @@
! tmp2(1:n_cols*l_cols) = tmp1(1:n_cols*l_cols)
#endif /* WITH_MPI */
!#ifdef WITH_GPU_VERSION
! istat = cuda_memcpy(tmp_dev, loc(tmp2), max_local_cols*nbw*size_of_datatype,cudaMemcpyHostToDevice)
! if (istat .ne. 0) then
! print *,"error in cudaMemcpy"
! stop 1
! endif
!#endif
if
(
l_rows
>
0
)
then
#ifdef WITH_MPI
! after the mpi_allreduce we have to copy back to the device
...
...
@@ -418,15 +390,7 @@
stop
1
endif
#else /* WITH_MPI */
#if COMPLEXCASE == 1
! check whether this could be avoided like in the real case
successCUDA
=
cuda_memcpy
(
tmp_dev
,
loc
(
tmp1
),
l_cols
*
n_cols
*
size_of_datatype
,
cudaMemcpyHostToDevice
)
if
(
.not.
(
successCUDA
))
then
print
*
,
"trans_ev_band_to_full_complex: error in cudaMemcpy"
stop
1
endif
#endif
! in real case no memcopy needed. Don't do it in complex case neither
#endif /* WITH_MPI */
!#ifdef WITH_MPI
...
...
@@ -451,7 +415,6 @@
tmp_dev
,
n_cols
,
one
,
q_dev
,
ldq
)
call
obj
%
timer
%
stop
(
"cublas"
)
#if REALCASE == 1
! copy to host maybe this can be avoided
! this is not necessary hvm is not used anymore
successCUDA
=
cuda_memcpy
(
loc
(
hvm
),
hvm_dev
,
((
max_local_rows
)
*
nbw
*
size_of_datatype
),
cudaMemcpyDeviceToHost
)
...
...
@@ -459,17 +422,8 @@
print
*
,
"trans_ev_band_to_full_real: error in cudaMemcpy"
stop
1
endif
#endif
endif
! l_rows > 0
!#ifdef WITH_GPU_VERSION
! istat = cuda_memcpy(loc(hvm), hvm_dev, ((max_local_rows)*nbw*size_of_datatype),cudaMemcpyDeviceToHost)
! if (istat .ne. 0) then
! print *,"error in cudaMemcpy"
! stop 1
! endif
!
!#endif
enddo
! istep
...
...
@@ -594,7 +548,8 @@
#ifdef BAND_TO_FULL_BLOCKING
! This the call when using na >= ((t_blocking+1)*nbw)
! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw ! Number of columns in current step
! n_cols = MIN(na,istep*cwy_blocking+nbw) - (istep-1)*cwy_blocking - nbw
! Number of columns in current step
! As an alternative we add some special case handling if na < cwy_blocking
IF
(
na
<
cwy_blocking
)
THEN
n_cols
=
MAX
(
0
,
na
-
nbw
)
...
...
@@ -631,7 +586,7 @@
#ifdef WITH_MPI
call
obj
%
timer
%
start
(
"mpi_communication"
)
call
MPI_Bcast
(
hvb
(
ns
+1
),
nb
-
ns
,
MPI_MATH_DATATYPE_PRECISION
,
&
pcol
(
ncol
,
nblk
,
np_cols
),
mpi_comm_cols
,
mpierr
)
pcol
(
ncol
,
nblk
,
np_cols
),
mpi_comm_cols
,
mpierr
)
call
obj
%
timer
%
stop
(
"mpi_communication"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment