Commit 56043bdc authored by Andreas Marek's avatar Andreas Marek
Browse files

Single precision support for ELPA2

ELPA2 can now be build (as ELPA1) for single precision calculations.
The ELPA2 kernles which are implemented in assembler, C, or C++ have NOT
yet been ported.

Thus at the moment only the GENERIC and GENERIC_SIMPLE kernels support
single precision calculations
parent de6a4fde
...@@ -146,6 +146,7 @@ contains ...@@ -146,6 +146,7 @@ contains
use precision use precision
use cuda_functions use cuda_functions
use mod_check_for_gpu use mod_check_for_gpu
use iso_c_binding
implicit none implicit none
logical, intent(in), optional :: useQR logical, intent(in), optional :: useQR
logical :: useQRActual, useQREnvironment logical :: useQRActual, useQREnvironment
...@@ -163,7 +164,7 @@ contains ...@@ -163,7 +164,7 @@ contains
integer(kind=ik) :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr integer(kind=ik) :: my_pe, n_pes, my_prow, my_pcol, np_rows, np_cols, mpierr
integer(kind=ik) :: nbw, num_blocks integer(kind=ik) :: nbw, num_blocks
real(kind=rk), allocatable :: tmat(:,:,:), e(:) real(kind=rk), allocatable :: tmat(:,:,:), e(:)
real(kind=rk) :: ttt0, ttt1, ttts real(kind=c_double) :: ttt0, ttt1, ttts ! MPI_WTIME always needs double
integer(kind=ik) :: i integer(kind=ik) :: i
logical :: success logical :: success
logical, save :: firstCall = .true. logical, save :: firstCall = .true.
...@@ -294,6 +295,7 @@ contains ...@@ -294,6 +295,7 @@ contains
ttts = ttt0 ttts = ttt0
call bandred_real(na, a, lda, nblk, nbw, matrixCols, num_blocks, mpi_comm_rows, mpi_comm_cols, & call bandred_real(na, a, lda, nblk, nbw, matrixCols, num_blocks, mpi_comm_rows, mpi_comm_cols, &
tmat, wantDebug, useGPU, success, useQRActual) tmat, wantDebug, useGPU, success, useQRActual)
if (.not.(success)) return if (.not.(success)) return
ttt1 = MPI_Wtime() ttt1 = MPI_Wtime()
if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) & if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) &
...@@ -308,14 +310,19 @@ contains ...@@ -308,14 +310,19 @@ contains
endif endif
ttt0 = MPI_Wtime() ttt0 = MPI_Wtime()
call tridiag_band_real(na, nbw, nblk, a, lda, ev, e, matrixCols, hh_trans_real, & call tridiag_band_real(na, nbw, nblk, a, lda, ev, e, matrixCols, hh_trans_real, &
mpi_comm_rows, mpi_comm_cols, mpi_comm_all) mpi_comm_rows, mpi_comm_cols, mpi_comm_all)
ttt1 = MPI_Wtime() ttt1 = MPI_Wtime()
if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) & if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) &
write(error_unit,*) 'Time tridiag_band_real :',ttt1-ttt0 write(error_unit,*) 'Time tridiag_band_real :',ttt1-ttt0
#ifdef DOUBLE_PRECISION_REAL
call mpi_bcast(ev,na,MPI_REAL8,0,mpi_comm_all,mpierr) call mpi_bcast(ev,na,MPI_REAL8,0,mpi_comm_all,mpierr)
call mpi_bcast(e,na,MPI_REAL8,0,mpi_comm_all,mpierr) call mpi_bcast(e,na,MPI_REAL8,0,mpi_comm_all,mpierr)
#else
call mpi_bcast(ev,na,MPI_REAL4,0,mpi_comm_all,mpierr)
call mpi_bcast(e,na,MPI_REAL4,0,mpi_comm_all,mpierr)
#endif
ttt1 = MPI_Wtime() ttt1 = MPI_Wtime()
time_evp_fwd = ttt1-ttts time_evp_fwd = ttt1-ttts
...@@ -426,6 +433,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -426,6 +433,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
use precision use precision
use cuda_functions use cuda_functions
use mod_check_for_gpu use mod_check_for_gpu
use iso_c_binding
implicit none implicit none
integer(kind=ik), intent(in), optional :: THIS_COMPLEX_ELPA_KERNEL_API integer(kind=ik), intent(in), optional :: THIS_COMPLEX_ELPA_KERNEL_API
integer(kind=ik) :: THIS_COMPLEX_ELPA_KERNEL integer(kind=ik) :: THIS_COMPLEX_ELPA_KERNEL
...@@ -440,7 +448,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -440,7 +448,7 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
integer(kind=ik) :: l_cols, l_rows, l_cols_nev, nbw, num_blocks integer(kind=ik) :: l_cols, l_rows, l_cols_nev, nbw, num_blocks
complex(kind=ck), allocatable :: tmat(:,:,:) complex(kind=ck), allocatable :: tmat(:,:,:)
real(kind=rk), allocatable :: q_real(:,:), e(:) real(kind=rk), allocatable :: q_real(:,:), e(:)
real(kind=rk) :: ttt0, ttt1, ttts real(kind=c_double) :: ttt0, ttt1, ttts ! MPI_WTIME always needs double
integer(kind=ik) :: i integer(kind=ik) :: i
logical :: success, wantDebug logical :: success, wantDebug
...@@ -562,9 +570,13 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, & ...@@ -562,9 +570,13 @@ function solve_evp_complex_2stage(na, nev, a, lda, ev, q, ldq, nblk, &
if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) & if (my_prow==0 .and. my_pcol==0 .and. elpa_print_times) &
write(error_unit,*) 'Time tridiag_band_complex :',ttt1-ttt0 write(error_unit,*) 'Time tridiag_band_complex :',ttt1-ttt0
call mpi_bcast(ev,na,MPI_REAL8,0,mpi_comm_all,mpierr) #ifdef DOUBLE_PRECISION_COMPLEX
call mpi_bcast(e,na,MPI_REAL8,0,mpi_comm_all,mpierr) call mpi_bcast(ev, na, mpi_real8, 0, mpi_comm_all, mpierr)
call mpi_bcast(e, na, mpi_real8, 0, mpi_comm_all, mpierr)
#else
call mpi_bcast(ev, na, mpi_real4, 0, mpi_comm_all, mpierr)
call mpi_bcast(e, na, mpi_real4, 0, mpi_comm_all, mpierr)
#endif
ttt1 = MPI_Wtime() ttt1 = MPI_Wtime()
time_evp_fwd = ttt1-ttts time_evp_fwd = ttt1-ttts
......
This diff is collapsed.
...@@ -34,8 +34,9 @@ module compute_hh_trafo_complex ...@@ -34,8 +34,9 @@ module compute_hh_trafo_complex
#ifdef HAVE_DETAILED_TIMINGS #ifdef HAVE_DETAILED_TIMINGS
use timings use timings
#endif #endif
use iso_c_binding
implicit none implicit none
real(kind=rk), intent(inout) :: kernel_time real(kind=c_double), intent(inout) :: kernel_time ! MPI_WTIME always needs double
integer(kind=lik) :: kernel_flops integer(kind=lik) :: kernel_flops
integer(kind=ik), intent(in) :: nbw, max_blk_size integer(kind=ik), intent(in) :: nbw, max_blk_size
complex(kind=ck) :: bcast_buffer(nbw,max_blk_size) complex(kind=ck) :: bcast_buffer(nbw,max_blk_size)
...@@ -57,7 +58,7 @@ module compute_hh_trafo_complex ...@@ -57,7 +58,7 @@ module compute_hh_trafo_complex
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
integer(kind=ik) :: my_thread, noff integer(kind=ik) :: my_thread, noff
#endif #endif
real(kind=rk) :: ttt real(kind=c_double) :: ttt ! MPI_WTIME always needs double
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Currently (on Sandy Bridge), single is faster than double ! Currently (on Sandy Bridge), single is faster than double
......
...@@ -52,7 +52,7 @@ module compute_hh_trafo_real ...@@ -52,7 +52,7 @@ module compute_hh_trafo_real
#endif #endif
implicit none implicit none
include "mpif.h" include "mpif.h"
real(kind=rk), intent(inout) :: kernel_time real(kind=c_double), intent(inout) :: kernel_time ! MPI_WTIME always needs double
integer(kind=lik) :: kernel_flops integer(kind=lik) :: kernel_flops
integer(kind=ik), intent(in) :: nbw, max_blk_size integer(kind=ik), intent(in) :: nbw, max_blk_size
real(kind=rk) :: bcast_buffer(nbw,max_blk_size) real(kind=rk) :: bcast_buffer(nbw,max_blk_size)
...@@ -82,7 +82,8 @@ module compute_hh_trafo_real ...@@ -82,7 +82,8 @@ module compute_hh_trafo_real
integer(kind=ik) :: my_thread, noff integer(kind=ik) :: my_thread, noff
#endif #endif
integer(kind=ik) :: j, nl, jj, jjj integer(kind=ik) :: j, nl, jj, jjj
real(kind=rk) :: w(nbw,6), ttt real(kind=rk) :: w(nbw,6)
real(kind=c_double) :: ttt ! MPI_WTIME always needs double
if (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GPU) then if (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GPU) then
! ncols - indicates the number of HH reflectors to apply; at least 1 must be available ! ncols - indicates the number of HH reflectors to apply; at least 1 must be available
......
...@@ -12,8 +12,17 @@ module cuda_functions ...@@ -12,8 +12,17 @@ module cuda_functions
integer(kind=ik) :: cudaHostRegisterMapped integer(kind=ik) :: cudaHostRegisterMapped
integer(kind=ik) :: cudaMemcpyDeviceToDevice integer(kind=ik) :: cudaMemcpyDeviceToDevice
integer(kind=c_size_t), parameter :: size_of_real_datatype = 8_8 #ifdef DOUBLE_PRECISION_REAL
integer(kind=c_size_t), parameter :: size_of_complex_datatype = 16_8 integer(kind=c_size_t), parameter :: size_of_real_datatype = 8_rk
#else
integer(kind=c_size_t), parameter :: size_of_real_datatype = 4_rk
#endif
#ifdef DOUBLE_PRECISION_COMPLEX
integer(kind=c_size_t), parameter :: size_of_complex_datatype = 16_ck
#else
integer(kind=c_size_t), parameter :: size_of_complex_datatype = 8_ck
#endif
! functions to set and query the CUDA devices ! functions to set and query the CUDA devices
...@@ -193,6 +202,20 @@ module cuda_functions ...@@ -193,6 +202,20 @@ module cuda_functions
end subroutine cublas_dgemm_c end subroutine cublas_dgemm_c
end interface end interface
interface
subroutine cublas_sgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) bind(C,name='cublasSgemm')
use iso_c_binding
implicit none
character(1,C_CHAR),value :: cta, ctb
integer(kind=C_INT),value :: m,n,k
integer(kind=C_INT), intent(in), value :: lda,ldb,ldc
real(kind=C_FLOAT),value :: alpha,beta
integer(kind=C_intptr_T), value :: a, b, c
end subroutine cublas_sgemm_c
end interface
interface interface
subroutine cublas_dtrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasDtrmm') subroutine cublas_dtrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasDtrmm')
...@@ -207,6 +230,20 @@ module cuda_functions ...@@ -207,6 +230,20 @@ module cuda_functions
end subroutine cublas_dtrmm_c end subroutine cublas_dtrmm_c
end interface end interface
interface
subroutine cublas_strmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasStrmm')
use iso_c_binding
implicit none
character(1,C_CHAR),value :: side, uplo, trans, diag
integer(kind=C_INT),value :: m,n
integer(kind=C_INT), intent(in), value :: lda,ldb
real(kind=C_FLOAT), value :: alpha
integer(kind=C_intptr_T), value :: a, b
end subroutine cublas_strmm_c
end interface
interface interface
subroutine cublas_zgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) bind(C,name='cublasZgemm') subroutine cublas_zgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) bind(C,name='cublasZgemm')
...@@ -222,6 +259,21 @@ module cuda_functions ...@@ -222,6 +259,21 @@ module cuda_functions
end subroutine cublas_zgemm_c end subroutine cublas_zgemm_c
end interface end interface
interface
subroutine cublas_cgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) bind(C,name='cublasCgemm')
use iso_c_binding
implicit none
character(1,C_CHAR),value :: cta, ctb
integer(kind=C_INT),value :: m,n,k
integer(kind=C_INT), intent(in), value :: lda,ldb,ldc
complex(kind=C_FLOAT),value :: alpha,beta
integer(kind=C_intptr_T), value :: a, b, c
end subroutine cublas_cgemm_c
end interface
interface interface
subroutine cublas_ztrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasZtrmm') subroutine cublas_ztrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasZtrmm')
...@@ -237,6 +289,22 @@ module cuda_functions ...@@ -237,6 +289,22 @@ module cuda_functions
end subroutine cublas_ztrmm_c end subroutine cublas_ztrmm_c
end interface end interface
interface
subroutine cublas_ctrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasCtrmm')
use iso_c_binding
implicit none
character(1,C_CHAR),value :: side, uplo, trans, diag
integer(kind=C_INT),value :: m,n
integer(kind=C_INT), intent(in), value :: lda,ldb
complex(kind=C_FLOAT), value :: alpha
integer(kind=C_intptr_T), value :: a, b
end subroutine cublas_ctrmm_c
end interface
contains contains
! functions to set and query the CUDA devices ! functions to set and query the CUDA devices
...@@ -448,6 +516,20 @@ module cuda_functions ...@@ -448,6 +516,20 @@ module cuda_functions
#endif #endif
end subroutine cublas_dgemm end subroutine cublas_dgemm
subroutine cublas_sgemm(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
use iso_c_binding
implicit none
character(1,C_CHAR) :: cta, ctb
integer(kind=C_INT) :: m,n,k
integer(kind=C_INT), intent(in) :: lda,ldb,ldc
real(kind=C_FLOAT) :: alpha,beta
integer(kind=C_intptr_T) :: a, b, c
#ifdef WITH_GPU_VERSION
call cublas_sgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
#endif
end subroutine cublas_sgemm
subroutine cublas_dtrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) subroutine cublas_dtrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
use iso_c_binding use iso_c_binding
...@@ -463,6 +545,21 @@ module cuda_functions ...@@ -463,6 +545,21 @@ module cuda_functions
#endif #endif
end subroutine cublas_dtrmm end subroutine cublas_dtrmm
subroutine cublas_strmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
use iso_c_binding
implicit none
character(1,C_CHAR) :: side, uplo, trans, diag
integer(kind=C_INT) :: m,n
integer(kind=C_INT), intent(in) :: lda,ldb
real(kind=C_FLOAT) :: alpha
integer(kind=C_intptr_T) :: a, b
#ifdef WITH_GPU_VERSION
call cublas_strmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
#endif
end subroutine cublas_strmm
subroutine cublas_zgemm(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) subroutine cublas_zgemm(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc)
use iso_c_binding use iso_c_binding
...@@ -478,6 +575,21 @@ module cuda_functions ...@@ -478,6 +575,21 @@ module cuda_functions
#endif #endif
end subroutine cublas_zgemm end subroutine cublas_zgemm
subroutine cublas_cgemm(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc)
use iso_c_binding
implicit none
character(1,C_CHAR) :: cta, ctb
integer(kind=C_INT) :: m,n,k
integer(kind=C_INT), intent(in) :: lda,ldb,ldc
complex(kind=C_FLOAT) :: alpha,beta
integer(kind=C_intptr_T) :: a, b, c
#ifdef WITH_GPU_VERSION
call cublas_cgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc)
#endif
end subroutine cublas_cgemm
subroutine cublas_ztrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) subroutine cublas_ztrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
use iso_c_binding use iso_c_binding
...@@ -493,6 +605,20 @@ module cuda_functions ...@@ -493,6 +605,20 @@ module cuda_functions
#endif #endif
end subroutine cublas_ztrmm end subroutine cublas_ztrmm
subroutine cublas_ctrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
use iso_c_binding
implicit none
character(1,C_CHAR) :: side, uplo, trans, diag
integer(kind=C_INT) :: m,n
integer(kind=C_INT), intent(in) :: lda,ldb
complex(kind=C_FLOAT) :: alpha
integer(kind=C_intptr_T) :: a, b
#ifdef WITH_GPU_VERSION
call cublas_ctrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
#endif
end subroutine cublas_ctrmm
end module cuda_functions end module cuda_functions
...@@ -73,11 +73,15 @@ int main(int argc, char** argv) { ...@@ -73,11 +73,15 @@ int main(int argc, char** argv) {
int na_rows, na_cols; int na_rows, na_cols;
double startVal; double startVal;
#ifdef DOUBLE_PRECISION_COMPLEX
complex double *a, *z, *as, *tmp1, *tmp2; complex double *a, *z, *as, *tmp1, *tmp2;
double *ev, *xr; double *ev, *xr;
#else
complex *a, *z, *as, *tmp1, *tmp2;
float *ev, *xr;
#endif
int *iseed; int *iseed;
int success; int success;
...@@ -105,6 +109,11 @@ int main(int argc, char** argv) { ...@@ -105,6 +109,11 @@ int main(int argc, char** argv) {
printf("\n"); printf("\n");
#ifdef DOUBLE_PRECISION_COMPLEX
printf(" Double precision version of ELPA2 is used. \n");
#else
printf(" Single precision version of ELPA2 is used. \n");
#endif
} }
status = 0; status = 0;
...@@ -162,7 +171,7 @@ int main(int argc, char** argv) { ...@@ -162,7 +171,7 @@ int main(int argc, char** argv) {
printf("Allocating matrices with na_rows=%d and na_cols=%d\n",na_rows, na_cols); printf("Allocating matrices with na_rows=%d and na_cols=%d\n",na_rows, na_cols);
printf("\n"); printf("\n");
} }
#ifdef DOUBLE_PRECISION_COMPLEX
a = malloc(na_rows*na_cols*sizeof(complex double)); a = malloc(na_rows*na_cols*sizeof(complex double));
z = malloc(na_rows*na_cols*sizeof(complex double)); z = malloc(na_rows*na_cols*sizeof(complex double));
as = malloc(na_rows*na_cols*sizeof(complex double)); as = malloc(na_rows*na_cols*sizeof(complex double));
...@@ -174,10 +183,25 @@ int main(int argc, char** argv) { ...@@ -174,10 +183,25 @@ int main(int argc, char** argv) {
tmp1 = malloc(na_rows*na_cols*sizeof(complex double)); tmp1 = malloc(na_rows*na_cols*sizeof(complex double));
tmp2 = malloc(na_rows*na_cols*sizeof(complex double)); tmp2 = malloc(na_rows*na_cols*sizeof(complex double));
#else
a = malloc(na_rows*na_cols*sizeof(complex));
z = malloc(na_rows*na_cols*sizeof(complex));
as = malloc(na_rows*na_cols*sizeof(complex));
xr = malloc(na_rows*na_cols*sizeof(float));
iseed = malloc(4096*sizeof(int));
prepare_matrix_complex_from_fortran(na, myid, na_rows, na_cols, sc_desc, iseed, xr, a, z, as); ev = malloc(na*sizeof(float));
tmp1 = malloc(na_rows*na_cols*sizeof(complex));
tmp2 = malloc(na_rows*na_cols*sizeof(complex));
#endif
iseed = malloc(4096*sizeof(int));
#ifdef DOUBLE_PRECISION_COMPLEX
prepare_matrix_complex_from_fortran_double_precision(na, myid, na_rows, na_cols, sc_desc, iseed, xr, a, z, as);
#else
prepare_matrix_complex_from_fortran_single_precision(na, myid, na_rows, na_cols, sc_desc, iseed, xr, a, z, as);
#endif
free(xr); free(xr);
...@@ -189,7 +213,11 @@ int main(int argc, char** argv) { ...@@ -189,7 +213,11 @@ int main(int argc, char** argv) {
mpierr = MPI_Barrier(MPI_COMM_WORLD); mpierr = MPI_Barrier(MPI_COMM_WORLD);
THIS_COMPLEX_ELPA_KERNEL_API = ELPA2_COMPLEX_KERNEL_GENERIC; THIS_COMPLEX_ELPA_KERNEL_API = ELPA2_COMPLEX_KERNEL_GENERIC;
success = elpa_solve_evp_complex_2stage(na, nev, a, na_rows, ev, z, na_rows, nblk, na_cols, mpi_comm_rows, mpi_comm_cols, my_mpi_comm_world, THIS_COMPLEX_ELPA_KERNEL_API); #ifdef DOUBLE_PRECISION_COMPLEX
success = elpa_solve_evp_complex_2stage_double_precision(na, nev, a, na_rows, ev, z, na_rows, nblk, na_cols, mpi_comm_rows, mpi_comm_cols, my_mpi_comm_world, THIS_COMPLEX_ELPA_KERNEL_API);
#else
success = elpa_solve_evp_complex_2stage_single_precision(na, nev, a, na_rows, ev, z, na_rows, nblk, na_cols, mpi_comm_rows, mpi_comm_cols, my_mpi_comm_world, THIS_COMPLEX_ELPA_KERNEL_API);
#endif
if (success != 1) { if (success != 1) {
printf("error in ELPA solve \n"); printf("error in ELPA solve \n");
...@@ -204,8 +232,11 @@ int main(int argc, char** argv) { ...@@ -204,8 +232,11 @@ int main(int argc, char** argv) {
} }
/* check the results */ /* check the results */
status = check_correctness_complex_from_fortran(na, nev, na_rows, na_cols, as, z, ev, sc_desc, myid, tmp1, tmp2); #ifdef DOUBLE_PRECISION_COMPLEX
status = check_correctness_complex_from_fortran_double_precision(na, nev, na_rows, na_cols, as, z, ev, sc_desc, myid, tmp1, tmp2);
#else
status = check_correctness_complex_from_fortran_single_precision(na, nev, na_rows, na_cols, as, z, ev, sc_desc, myid, tmp1, tmp2);
#endif
if (status !=0){ if (status !=0){
printf("The computed EVs are not correct !\n"); printf("The computed EVs are not correct !\n");
} }
......
...@@ -72,9 +72,11 @@ int main(int argc, char** argv) { ...@@ -72,9 +72,11 @@ int main(int argc, char** argv) {
int na_rows, na_cols; int na_rows, na_cols;
double startVal; double startVal;
#ifdef DOUBLE_PRECISION_REAL
double *a, *z, *as, *ev, *tmp1, *tmp2; double *a, *z, *as, *ev, *tmp1, *tmp2;
#else
float *a, *z, *as, *ev, *tmp1, *tmp2;
#endif
int *iseed; int *iseed;
int success; int success;
...@@ -100,7 +102,11 @@ int main(int argc, char** argv) { ...@@ -100,7 +102,11 @@ int main(int argc, char** argv) {
printf("as it's Fortran counterpart. It's only purpose is to show how \n"); printf("as it's Fortran counterpart. It's only purpose is to show how \n");
printf("to evoke ELPA1 from a c programm\n"); printf("to evoke ELPA1 from a c programm\n");
printf("\n"); printf("\n");
#ifdef DOUBLE_PRECISION_REAL
printf(" Double precision version of ELPA2 is used. \n");
#else
printf(" Single precision version of ELPA2 is used. \n");
#endif
} }
status = 0; status = 0;
...@@ -158,7 +164,7 @@ int main(int argc, char** argv) { ...@@ -158,7 +164,7 @@ int main(int argc, char** argv) {
printf("Allocating matrices with na_rows=%d and na_cols=%d\n",na_rows, na_cols); printf("Allocating matrices with na_rows=%d and na_cols=%d\n",na_rows, na_cols);
printf("\n"); printf("\n");
} }
#ifdef DOUBLE_PRECISION_REAL
a = malloc(na_rows*na_cols*sizeof(double)); a = malloc(na_rows*na_cols*sizeof(double));
z = malloc(na_rows*na_cols*sizeof(double)); z = malloc(na_rows*na_cols*sizeof(double));
as = malloc(na_rows*na_cols*sizeof(double)); as = malloc(