Commit 90a156d6 authored by Andreas Marek's avatar Andreas Marek
Browse files

Test for cholesky decomposition with new interface directly

parent 174c6c48
......@@ -28,6 +28,7 @@ test_type_flag = {
"eigenvectors" : "-DTEST_EIGENVECTORS",
"eigenvalues" : "-DTEST_EIGENVALUES",
"solve_tridiagonal" : "-DTEST_SOLVE_TRIDIAGONAL",
"cholesky" : "-DTEST_CHOLESKY",
}
layout_flag = {
......@@ -54,6 +55,9 @@ for m, g, t, p, d, s, l in product(
if (t == "solve_tridiagonal" and (s == "2stage" or d == "complex")):
continue
if (t == "cholesky" and (s == "2stage")):
continue
for kernel in ["all_kernels", "default_kernel"] if s == "2stage" else ["nokernel"]:
endifs = 0
extra_flags = []
......
......@@ -139,7 +139,7 @@ program test
MATRIX_TYPE, allocatable :: z(:,:)
! eigenvalues
EV_TYPE, allocatable :: ev(:)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
EV_TYPE, allocatable :: d(:), sd(:), ev_analytic(:), ds(:), sds(:)
EV_TYPE :: diagonalELement, subdiagonalElement
#endif
......@@ -218,7 +218,7 @@ program test
allocate(z (na_rows,na_cols))
allocate(ev(na))
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
allocate(d (na), ds(na))
allocate(sd (na), sds(na))
allocate(ev_analytic(na))
......@@ -265,6 +265,20 @@ program test
np_cols, my_prow, my_pcol)
#endif /* EIGENVALUES OR TRIDIAGONAL */
#if defined(TEST_CHOLESKY)
#ifdef TEST_SINGLE
diagonalElement = 2.546_c_float
subdiagonalElement = 0.0_c_float
#else
diagonalElement = 2.546_c_double
subdiagonalElement = 0.0_c_double
#endif
call prepare_toeplitz_matrix(na, diagonalElement, subdiagonalElement, &
d, sd, ds, sds, a, as, nblk, np_rows, &
np_cols, my_prow, my_pcol)
#endif /* TEST_CHOLESKY */
e => elpa_allocate()
call e%set("na", na, error)
......@@ -340,7 +354,7 @@ program test
call e%eigenvectors(a, ev, z, error)
#endif
call e%timer_stop("e%eigenvectors()")
#endif
#endif /* TEST_EIGENVECTORS */
#ifdef TEST_EIGENVALUES
call e%timer_start("e%eigenvalues()")
......@@ -355,6 +369,12 @@ program test
ev(:) = d(:)
#endif
#if defined(TEST_CHOLESKY)
call e%timer_start("e%cholesky()")
call e%cholesky(a, error)
call e%timer_stop("e%cholesky()")
#endif
assert_elpa_ok(error)
......@@ -365,7 +385,8 @@ program test
if (myid .eq. 0) then
#ifdef TEST_ALL_KERNELS
call e%print_times(elpa_int_value_to_string(KERNEL_KEY, kernel))
#else
#else /* TEST_ALL_KERNELS */
#ifdef TEST_EIGENVECTORS
call e%print_times("e%eigenvectors()")
#endif
......@@ -375,7 +396,10 @@ program test
#ifdef TEST_SOLVE_TRIDIAGONAL
call e%print_times("e%solve_tridiagonal()")
#endif
#ifdef TEST_CHOLESKY
call e%print_times("e%cholesky()")
#endif
#endif /* TEST_ALL_KERNELS */
endif
#ifdef TEST_EIGENVECTORS
......@@ -405,13 +429,18 @@ program test
#endif
#endif
#if defined(TEST_CHOLESKY)
status = check_correctness_cholesky(na, a, as, na_rows, sc_desc, myid )
call check_status(status, myid)
#endif
if (myid == 0) then
print *, ""
endif
#ifdef TEST_ALL_KERNELS
a(:,:) = as(:,:)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
d = ds
sd = sds
#endif
......@@ -425,7 +454,7 @@ program test
deallocate(z)
deallocate(ev)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
deallocate(d, ds)
deallocate(sd, sds)
deallocate(ev_analytic)
......
......@@ -66,6 +66,18 @@ module test_check_correctness
module procedure check_correctness_eigenvalues_toeplitz_complex_single
#endif
end interface
interface check_correctness_cholesky
module procedure check_correctness_cholesky_complex_double
module procedure check_correctness_cholesky_real_double
#ifdef WANT_SINGLE_PRECISION_REAL
module procedure check_correctness_cholesky_real_single
#endif
#ifdef WANT_SINGLE_PRECISION_COMPLEX
module procedure check_correctness_cholesky_complex_single
#endif
end interface
contains
#define COMPLEXCASE 1
......
......@@ -454,4 +454,241 @@ function check_correctness_&
endif
end function
function check_correctness_cholesky_&
&MATH_DATATYPE&
&_&
&PRECISION&
& (na, a, as, na_rows, sc_desc, myid) result(status)
implicit none
#include "../../src/general/precision_kinds.F90"
integer(kind=ik) :: status
integer(kind=ik), intent(in) :: na, myid, na_rows
#if REALCASE == 1
real(kind=rck), intent(in) :: a(:,:), as(:,:)
real(kind=rck), dimension(size(as,dim=1),size(as,dim=2)) :: tmp1, tmp2
real(kind=rck) :: norm, normmax
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
real(kind=rck) :: pdlange
#else
real(kind=rck) :: pslange
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_REAL
real(kind=rck) :: dlange
#else
real(kind=rck) :: slange
#endif
#endif /* WITH_MPI */
#endif /* REALCASE */
#if COMPLEXCASE == 1
complex(kind=rck), intent(in) :: a(:,:), as(:,:)
complex(kind=rck), dimension(size(as,dim=1),size(as,dim=2)) :: tmp1, tmp2
real(kind=rck) :: norm, normmax
#ifdef DOUBLE_PRECISION_COMPLEX
complex(kind=ck8), parameter :: CZERO = (0.0_rk8,0.0_rk8), CONE = (1.0_rk8,0.0_rk8)
#else
complex(kind=ck4), parameter :: CZERO = (0.0_rk4,0.0_rk4), CONE = (1.0_rk4,0.0_rk8)
#endif
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
complex(kind=rck) :: pzlange
#else
complex(kind=rck) :: pclange
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_COMPLEX
complex(kind=rck) :: zlange
#else
complex(kind=rck) :: clange
#endif
#endif /* WITH_MPI */
#endif /* COMPLEXCASE */
integer(kind=ik) :: sc_desc(:)
real(kind=rck) :: err, errmax
integer :: mpierr
status = 0
#if REALCASE == 1
tmp1(:,:) = 0.0_rk8
#endif
#if COMPLEXCASE == 1
tmp1(:,:) = 0.0_ck8
#endif
#if REALCASE == 1
! tmp1 = a**T
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
call pdtran(na, na, 1.0_rk8, a, 1, 1, sc_desc, 0.0_rk8, tmp1, 1, 1, sc_desc)
#else
call pstran(na, na, 1.0_rk4, a, 1, 1, sc_desc, 0.0_rk4, tmp1, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
tmp1 = transpose(a)
#endif /* WITH_MPI */
! tmp2 = a * a**T
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
call pdgemm("N","N", na, na, na, 1.0_rk8, a, 1, 1, sc_desc, tmp1, 1, 1, &
sc_desc, 0.0_rk8, tmp2, 1, 1, sc_desc)
#else
call psgemm("N","N", na, na, na, 1.0_rk4, a, 1, 1, sc_desc, tmp1, 1, 1, &
sc_desc, 0.0_rk4, tmp2, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_REAL
call dgemm("N","N", na, na, na, 1.0_rk8, a, na, tmp1, na, 0.0_rk8, tmp2, na)
#else
call sgemm("N","N", na, na, na, 1.0_rk4, a, na, tmp1, na, 0.0_rk4, tmp2, na)
#endif
#endif /* WITH_MPI */
#endif /* REALCASE == 1 */
#if COMPLEXCASE == 1
! tmp1 = a**H
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
call pztranc(na, na, CONE, a, 1, 1, sc_desc, CZERO, tmp1, 1, 1, sc_desc)
#else
call pctranc(na, na, CONE, a, 1, 1, sc_desc, CZERO, tmp1, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
tmp1 = transpose(conjg(a))
#endif /* WITH_MPI */
! tmp2 = a * a**T
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
call pzgemm("N","N", na, na, na, CONE, a, 1, 1, sc_desc, tmp1, 1, 1, &
sc_desc, CZERO, tmp2, 1, 1, sc_desc)
#else
call pcgemm("N","N", na, na, na, CONE, a, 1, 1, sc_desc, tmp1, 1, 1, &
sc_desc, CZERO, tmp2, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_COMPLEX
call zgemm("N","N", na, na, na, CONE, a, na, tmp1, na, CZERO, tmp2, na)
#else
call cgemm("N","N", na, na, na, CONE, a, na, tmp1, na, CZERO, tmp2, na)
#endif
#endif /* WITH_MPI */
#endif /* COMPLEXCASE == 1 */
! compare tmp2 with original matrix
tmp2(:,:) = tmp2(:,:) - as(:,:)
#if REALCASE == 1
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
norm = pdlange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#else
norm = pslange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_REAL
norm = dlange("M", na, na, tmp2, na_rows, tmp1)
#else
norm = slange("M", na, na, tmp2, na_rows, tmp1)
#endif
#endif /* WITH_MPI */
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
call mpi_allreduce(norm,normmax,1,MPI_REAL8,MPI_MAX,MPI_COMM_WORLD,mpierr)
#else
call mpi_allreduce(norm,normmax,1,MPI_REAL4,MPI_MAX,MPI_COMM_WORLD,mpierr)
#endif
#else /* WITH_MPI */
normmax = norm
#endif /* WITH_MPI */
#endif /* REALCASE == 1 */
#if COMPLEXCASE == 1
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
norm = pzlange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#else
norm = pclange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_COMPLEX
norm = zlange("M", na, na, tmp2, na_rows, tmp1)
#else
norm = clange("M", na, na, tmp2, na_rows, tmp1)
#endif
#endif /* WITH_MPI */
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
call mpi_allreduce(norm,normmax,1,MPI_REAL8,MPI_MAX,MPI_COMM_WORLD,mpierr)
#else
call mpi_allreduce(norm,normmax,1,MPI_REAL4,MPI_MAX,MPI_COMM_WORLD,mpierr)
#endif
#else /* WITH_MPI */
normmax = norm
#endif /* WITH_MPI */
#endif /* REALCASE == 1 */
if (myid .eq. 0) then
print *," Maximum error of result: ", normmax
endif
#ifdef DOUBLE_PRECISION_REAL
if (normmax .gt. 5e-12_rk8) then
status = 1
endif
#else
if (normmax .gt. 5e-4_rk4) then
status = 1
endif
#endif
end function
! vim: syntax=fortran
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment