Commit a9f1fc4b authored by Andreas Marek

Pass number of OMP threads in subroutines

parent ab19611c
...@@ -66,6 +66,9 @@ function elpa_solve_evp_& ...@@ -66,6 +66,9 @@ function elpa_solve_evp_&
use elpa_abstract_impl use elpa_abstract_impl
use elpa_mpi use elpa_mpi
use elpa1_compute use elpa1_compute
#ifdef WITH_OPENMP
use omp_lib
#endif
implicit none implicit none
#include "../general/precision_kinds.F90" #include "../general/precision_kinds.F90"
class(elpa_abstract_impl_t), intent(inout) :: obj class(elpa_abstract_impl_t), intent(inout) :: obj
...@@ -110,7 +113,7 @@ function elpa_solve_evp_& ...@@ -110,7 +113,7 @@ function elpa_solve_evp_&
mpi_comm_all, check_pd, i, error mpi_comm_all, check_pd, i, error
logical :: do_bandred, do_solve, do_trans_ev logical :: do_bandred, do_solve, do_trans_ev
integer(kind=ik) :: nrThreads, omp_get_num_threads integer(kind=ik) :: nrThreads
call obj%timer%start("elpa_solve_evp_& call obj%timer%start("elpa_solve_evp_&
&MATH_DATATYPE& &MATH_DATATYPE&
...@@ -119,7 +122,7 @@ function elpa_solve_evp_& ...@@ -119,7 +122,7 @@ function elpa_solve_evp_&
&") &")
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
nrThreads = omp_get_num_threads() nrThreads = omp_get_max_threads()
#else #else
nrThreads = 1 nrThreads = 1
#endif #endif
...@@ -310,7 +313,7 @@ function elpa_solve_evp_& ...@@ -310,7 +313,7 @@ function elpa_solve_evp_&
if (obj%eigenvalues_only) then if (obj%eigenvalues_only) then
do_trans_ev = .true. do_trans_ev = .true.
endif endif
print *,"ELPA 1 ",nrThreads
if (do_bandred) then if (do_bandred) then
call obj%timer%start("forward") call obj%timer%start("forward")
call tridiag_& call tridiag_&
......
...@@ -535,11 +535,6 @@ call prmat(na,useGpu,a_mat,a_dev,lda,matrixCols,nblk,my_prow,my_pcol,np_rows,np_ ...@@ -535,11 +535,6 @@ call prmat(na,useGpu,a_mat,a_dev,lda,matrixCols,nblk,my_prow,my_pcol,np_rows,np_
my_thread = omp_get_thread_num() my_thread = omp_get_thread_num()
n_threads = omp_get_num_threads() n_threads = omp_get_num_threads()
! debug REMOVE again
print *,"debug"
if (n_threads .ne. max_threads) then
print *,"WTF?"
endif
n_iter = 0 n_iter = 0
......
...@@ -71,7 +71,7 @@ ...@@ -71,7 +71,7 @@
logical :: success logical :: success
integer(kind=ik) :: istat, debug, error integer(kind=ik) :: istat, debug, error
character(200) :: errorMessage character(200) :: errorMessage
integer(kind=ik) :: max_threads integer(kind=ik) :: nrThreads
call obj%timer%start("elpa_cholesky_& call obj%timer%start("elpa_cholesky_&
&MATH_DATATYPE& &MATH_DATATYPE&
...@@ -80,9 +80,9 @@ ...@@ -80,9 +80,9 @@
&") &")
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
max_threads=omp_get_num_threads() nrThreads=omp_get_max_threads()
#else #else
max_threads=1 nrThreads=1
#endif #endif
na = obj%na na = obj%na
...@@ -295,7 +295,7 @@ ...@@ -295,7 +295,7 @@
&PRECISION & &PRECISION &
(obj, tmatc, ubound(tmatc,dim=1), mpi_comm_cols, & (obj, tmatc, ubound(tmatc,dim=1), mpi_comm_cols, &
tmatr, ubound(tmatr,dim=1), mpi_comm_rows, & tmatr, ubound(tmatr,dim=1), mpi_comm_rows, &
n, na, nblk, nblk, max_threads) n, na, nblk, nblk, nrThreads)
do i=0,(na-1)/tile_size do i=0,(na-1)/tile_size
lcs = max(l_colx,i*l_cols_tile+1) lcs = max(l_colx,i*l_cols_tile+1)
......
...@@ -95,7 +95,7 @@ ...@@ -95,7 +95,7 @@
matrixCols = obj%local_ncols matrixCols = obj%local_ncols
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
max_threads=omp_get_num_threads() max_threads=omp_get_max_threads()
#else #else
max_threads=1 max_threads=1
#endif #endif
......
...@@ -1463,13 +1463,9 @@ ...@@ -1463,13 +1463,9 @@
! A = A - V*U**T - U*V**T ! A = A - V*U**T - U*V**T
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
! OPENMP_CHANGE here
!$omp parallel private( ii, i, lcs, lce, lre, n_way, m_way, m_id, n_id, work_per_thread, mystart, myend ) !$omp parallel private( ii, i, lcs, lce, lre, n_way, m_way, m_id, n_id, work_per_thread, mystart, myend )
n_threads = omp_get_num_threads() n_threads = omp_get_num_threads()
print *,"debug"
if (n_threads .ne. max_threads) then
print *,"WTF2"
endif
if (mod(n_threads, 2) == 0) then if (mod(n_threads, 2) == 0) then
n_way = 2 n_way = 2
else else
......
...@@ -490,7 +490,7 @@ ...@@ -490,7 +490,7 @@
&_& &_&
&PRECISION& &PRECISION&
(obj, na, nbw, nblk, a, a_dev, lda, ev, e, matrixCols, hh_trans, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, & (obj, na, nbw, nblk, a, a_dev, lda, ev, e, matrixCols, hh_trans, mpi_comm_rows, mpi_comm_cols, mpi_comm_all, &
do_useGPU, wantDebug) do_useGPU, wantDebug, nrThreads)
#ifdef WITH_MPI #ifdef WITH_MPI
call obj%timer%start("mpi_communication") call obj%timer%start("mpi_communication")
......
...@@ -97,7 +97,7 @@ ...@@ -97,7 +97,7 @@
use precision use precision
use iso_c_binding use iso_c_binding
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
use omp_lib ! use omp_lib
#endif #endif
implicit none implicit none
#include "../general/precision_kinds.F90" #include "../general/precision_kinds.F90"
...@@ -219,11 +219,6 @@ ...@@ -219,11 +219,6 @@
kernel_time = 0.0 kernel_time = 0.0
kernel_flops = 0 kernel_flops = 0
!#ifdef WITH_OPENMP
! ! openmp_change_here
! max_threads = 1
! max_threads = omp_get_max_threads()
!#endif
if (wantDebug) call obj%timer%start("mpi_communication") if (wantDebug) call obj%timer%start("mpi_communication")
call MPI_Comm_rank(mpi_comm_rows, my_prow, mpierr) call MPI_Comm_rank(mpi_comm_rows, my_prow, mpierr)
call MPI_Comm_size(mpi_comm_rows, np_rows, mpierr) call MPI_Comm_size(mpi_comm_rows, np_rows, mpierr)
......
...@@ -56,7 +56,7 @@ ...@@ -56,7 +56,7 @@
&_& &_&
&PRECISION & &PRECISION &
(obj, na, nb, nblk, aMatrix, a_dev, lda, d, e, matrixCols, & (obj, na, nb, nblk, aMatrix, a_dev, lda, d, e, matrixCols, &
hh_trans, mpi_comm_rows, mpi_comm_cols, communicator, useGPU, wantDebug) hh_trans, mpi_comm_rows, mpi_comm_cols, communicator, useGPU, wantDebug, nrThreads)
!------------------------------------------------------------------------------- !-------------------------------------------------------------------------------
! tridiag_band_real/complex: ! tridiag_band_real/complex:
! Reduces a real symmetric band matrix to tridiagonal form ! Reduces a real symmetric band matrix to tridiagonal form
...@@ -89,6 +89,9 @@ ...@@ -89,6 +89,9 @@
use precision use precision
use iso_c_binding use iso_c_binding
use redist use redist
#ifdef WITH_OPENMP
use omp_lib
#endif
implicit none implicit none
#include "../general/precision_kinds.F90" #include "../general/precision_kinds.F90"
class(elpa_abstract_impl_t), intent(inout) :: obj class(elpa_abstract_impl_t), intent(inout) :: obj
...@@ -112,16 +115,14 @@ ...@@ -112,16 +115,14 @@
integer(kind=ik) :: my_prow, np_rows, my_pcol, np_cols integer(kind=ik) :: my_prow, np_rows, my_pcol, np_cols
integer(kind=ik) :: ireq_ab, ireq_hv integer(kind=ik) :: ireq_ab, ireq_hv
integer(kind=ik) :: na_s, nx, num_hh_vecs, num_chunks, local_size, max_blk_size, n_off integer(kind=ik) :: na_s, nx, num_hh_vecs, num_chunks, local_size, max_blk_size, n_off
integer(kind=ik), intent(in) :: nrThreads
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
integer(kind=ik) :: max_threads, my_thread, my_block_s, my_block_e, iter integer(kind=ik) :: max_threads, my_thread, my_block_s, my_block_e, iter
#ifdef WITH_MPI #ifdef WITH_MPI
! integer(kind=ik) :: my_mpi_status(MPI_STATUS_SIZE)
#endif #endif
! integer(kind=ik), allocatable :: mpi_statuses(:,:), global_id_tmp(:,:)
integer(kind=ik), allocatable :: global_id_tmp(:,:) integer(kind=ik), allocatable :: global_id_tmp(:,:)
integer(kind=ik), allocatable :: omp_block_limits(:) integer(kind=ik), allocatable :: omp_block_limits(:)
MATH_DATATYPE(kind=rck), allocatable :: hv_t(:,:), tau_t(:) MATH_DATATYPE(kind=rck), allocatable :: hv_t(:,:), tau_t(:)
integer(kind=ik) :: omp_get_max_threads
#endif /* WITH_OPENMP */ #endif /* WITH_OPENMP */
integer(kind=ik), allocatable :: ireq_hhr(:), ireq_hhs(:), global_id(:,:), hh_cnt(:), hh_dst(:) integer(kind=ik), allocatable :: ireq_hhr(:), ireq_hhs(:), global_id(:,:), hh_cnt(:), hh_dst(:)
integer(kind=ik), allocatable :: limits(:), snd_limits(:,:) integer(kind=ik), allocatable :: limits(:), snd_limits(:,:)
...@@ -379,15 +380,7 @@ ...@@ -379,15 +380,7 @@
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
! OpenMP work distribution: ! OpenMP work distribution:
max_threads = nrThreads
max_threads = 1
#if REALCASE == 1
! OPENMP_CHANGE here
max_threads = omp_get_max_threads()
#endif
#if COMPLEXCASE == 1
!$ max_threads = omp_get_max_threads()
#endif
! For OpenMP we need at least 2 blocks for every thread ! For OpenMP we need at least 2 blocks for every thread
max_threads = MIN(max_threads, nblocks/2) max_threads = MIN(max_threads, nblocks/2)
if (max_threads==0) max_threads = 1 if (max_threads==0) max_threads = 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment