Commit ea6daa06 authored by Andreas Marek's avatar Andreas Marek
Browse files

Enable GPU version with OpenMP in ELPA 2stage

The GPU logic has been implemented in the OpenMP code paths in ELPA2.
Currently, this implies that _internal_ to ELPA2, the number of OpenMP
threads is set to one (independent of how many threads the calling
application uses) and the original value is restored at the end of ELPA.
Although this is not the general case, it is _not_ a limitation, since
in the GPU case no work is done on the CPU and thus threading would
not help.
parent ff3f445a
......@@ -57,7 +57,7 @@
#include "cuUtils_template.cu"
#undef DOUBLE_PRECISION_REAL
#if WANT_SINGLE_PRECISION_REAL
#ifdef WANT_SINGLE_PRECISION_REAL
#undef DOUBLE_PRECISION_REAL
#include "cuUtils_template.cu"
......@@ -71,7 +71,7 @@
#include "cuUtils_template.cu"
#undef DOUBLE_PRECISION_COMPLEX
#if WANT_SINGLE_PRECISION_COMPLEX
#ifdef WANT_SINGLE_PRECISION_COMPLEX
#undef DOUBLE_PRECISION_COMPLEX
#include "cuUtils_template.cu"
......
......@@ -137,7 +137,7 @@ module elpa1_auxiliary_impl
#undef DOUBLE_PRECISION
#undef REALCASE
#if WANT_SINGLE_PRECISION_REAL
#ifdef WANT_SINGLE_PRECISION_REAL
#define REALCASE 1
#define SINGLE_PRECISION
#include "../general/precision_macros.h"
......@@ -287,7 +287,7 @@ module elpa1_auxiliary_impl
#undef DOUBLE_PRECISION
#undef REALCASE
#if WANT_SINGLE_PRECISION_REAL
#ifdef WANT_SINGLE_PRECISION_REAL
#define REALCASE 1
#define SINGLE_PRECISION
#include "../general/precision_macros.h"
......
......@@ -58,11 +58,10 @@ l_nev, &
a_off, nbw, max_blk_size, bcast_buffer, bcast_buffer_dev, &
hh_tau_dev, kernel_flops, kernel_time, n_times, off, ncols, istripe, &
#ifdef WITH_OPENMP_TRADITIONAL
my_thread, thread_width, &
my_thread, thread_width, kernel, last_stripe_width)
#else
last_stripe_width, &
last_stripe_width, kernel)
#endif
kernel)
use precision
use elpa_abstract_impl
......@@ -141,6 +140,7 @@ kernel)
#else /* WITH_OPENMP_TRADITIONAL */
integer(kind=ik), intent(in) :: l_nev, thread_width
integer(kind=ik), intent(in), optional :: last_stripe_width
#if REALCASE == 1
! real(kind=C_DATATYPE_KIND) :: a(stripe_width,a_dim2,stripe_count,max_threads)
real(kind=C_DATATYPE_KIND), pointer :: a(:,:,:,:)
......@@ -221,54 +221,39 @@ kernel)
#ifdef WITH_OPENMP_TRADITIONAL
if (my_thread==1) then
if (my_thread==1) then ! in the calling routine threads go form 1 .. max_threads
#endif
ttt = mpi_wtime()
#ifdef WITH_OPENMP_TRADITIONAL
endif
#endif
#ifdef WITH_OPENMP_TRADITIONAL
#if REALCASE == 1
if (kernel .eq. ELPA_2STAGE_REAL_GPU) then
print *,"compute_hh_trafo_&
&MATH_DATATYPE&
&_GPU OPENMP: not yet implemented"
stop 1
endif
#endif
#if COMPLEXCASE == 1
if (kernel .eq. ELPA_2STAGE_COMPLEX_GPU) then
print *,"compute_hh_trafo_&
&MATH_DATATYPE&
&_GPU OPENMP: not yet implemented"
stop 1
endif
#endif
#endif /* WITH_OPENMP_TRADITIONAL */
#ifndef WITH_OPENMP_TRADITIONAL
nl = merge(stripe_width, last_stripe_width, istripe<stripe_count)
#else /* WITH_OPENMP_TRADITIONAL */
if (istripe<stripe_count) then
nl = stripe_width
if (present(last_stripe_width)) then
nl = merge(stripe_width, last_stripe_width, istripe<stripe_count)
else
noff = (my_thread-1)*thread_width + (istripe-1)*stripe_width
nl = min(my_thread*thread_width-noff, l_nev-noff)
if (nl<=0) then
if (wantDebug) call obj%timer%stop("compute_hh_trafo_&
&MATH_DATATYPE&
if (istripe<stripe_count) then
nl = stripe_width
else
noff = (my_thread-1)*thread_width + (istripe-1)*stripe_width
nl = min(my_thread*thread_width-noff, l_nev-noff)
if (nl<=0) then
if (wantDebug) call obj%timer%stop("compute_hh_trafo_&
&MATH_DATATYPE&
#ifdef WITH_OPENMP_TRADITIONAL
&_openmp" // &
&_openmp" // &
#else
&" // &
&" // &
#endif
&PRECISION_SUFFIX &
)
&PRECISION_SUFFIX &
)
return
return
endif
endif
endif
#endif /* not WITH_OPENMP_TRADITIONAL */
......
......@@ -528,8 +528,13 @@ subroutine tridiag_band_&
! with MPI calls
call obj%timer%start("OpenMP parallel" // PRECISION_SUFFIX)
!$omp parallel do private(my_thread, my_block_s, my_block_e, iblk, ns, ne, hv, tau, &
!$omp& nc, nr, hs, hd, vnorm2, hf, x, h, i), schedule(static,1), num_threads(max_threads)
!$omp parallel do &
!$omp default(none) &
!$omp private(my_thread, my_block_s, my_block_e, iblk, ns, ne, hv, tau, &
!$omp& nc, nr, hs, hd, vnorm2, hf, x, h, i) &
!$omp shared(max_threads, obj, ab, isSkewsymmetric, wantDebug, hh_gath, &
!$omp hh_cnt, tau_t, hv_t, na, istep, n_off, na_s, nb, omp_block_limits, iter) &
!$omp schedule(static,1), num_threads(max_threads)
do my_thread = 1, max_threads
if (iter == 1) then
......@@ -1087,7 +1092,7 @@ subroutine tridiag_band_&
endif
#endif
#if WITH_OPENMP_TRADITIONAL
#ifdef WITH_OPENMP_TRADITIONAL
do iblk = 1, nblocks
if (hh_dst(iblk) >= np_rows) exit
......
......@@ -57,7 +57,7 @@ module qr_utils_mod
public :: reverse_matrix_1dcomm_double
public :: reverse_matrix_2dcomm_ref_double
#if WANT_SINGLE_PRECISION_REAL
#ifdef WANT_SINGLE_PRECISION_REAL
public :: reverse_vector_local_single
public :: reverse_matrix_local_single
public :: reverse_matrix_1dcomm_single
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment