Commit 30ded3ab authored by Andreas Marek

Test hermitian_multiply decomposition with new interface directly

parent 90a156d6
......@@ -29,6 +29,7 @@ test_type_flag = {
"eigenvalues" : "-DTEST_EIGENVALUES",
"solve_tridiagonal" : "-DTEST_SOLVE_TRIDIAGONAL",
"cholesky" : "-DTEST_CHOLESKY",
"hermitian_multiply" : "-DTEST_HERMITIAN_MULTIPLY",
}
layout_flag = {
......@@ -58,6 +59,12 @@ for m, g, t, p, d, s, l in product(
if (t == "cholesky" and (s == "2stage")):
continue
if (t == "hermitian_multiply" and (s == "2stage")):
continue
if (t == "hermitian_multiply" and (p == "single")):
continue
for kernel in ["all_kernels", "default_kernel"] if s == "2stage" else ["nokernel"]:
endifs = 0
extra_flags = []
......
......@@ -57,7 +57,7 @@ error: define exactly one of TEST_SINGLE or TEST_DOUBLE
#endif
#if !(defined(TEST_SOLVER_1STAGE) ^ defined(TEST_SOLVER_2STAGE) ^ defined(TEST_SCALAPACK_ALL))
error: define exactly one of TEST_SOLVER_1STAGE or TEST_SOLVER_2STAGE or TEST_SCALAPACK_ALL
error: define exactly one of TEST_SOLVER_1STAGE or TEST_SOLVER_2STAGE or TEST_SCALAPACK_ALL
#endif
#ifdef TEST_SOLVER_1STAGE
......@@ -135,15 +135,22 @@ program test
! The Matrix
MATRIX_TYPE, allocatable :: a(:,:), as(:,:)
#if defined(TEST_HERMITIAN_MULTIPLY)
MATRIX_TYPE, allocatable :: b(:,:), c(:,:)
#endif
! eigenvectors
MATRIX_TYPE, allocatable :: z(:,:)
! eigenvalues
EV_TYPE, allocatable :: ev(:)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
EV_TYPE, allocatable :: d(:), sd(:), ev_analytic(:), ds(:), sds(:)
EV_TYPE, allocatable :: ev(:), ev_analytic(:)
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_HERMITIAN_MULTIPLY)
EV_TYPE, allocatable :: d(:), sd(:), ds(:), sds(:)
EV_TYPE :: diagonalELement, subdiagonalElement
#endif
#if defined(TEST_CHOLESKY)
MATRIX_TYPE, allocatable :: d(:), sd(:), ds(:), sds(:)
MATRIX_TYPE :: diagonalELement, subdiagonalElement
#endif
integer :: error, status
......@@ -218,6 +225,11 @@ program test
allocate(z (na_rows,na_cols))
allocate(ev(na))
#ifdef TEST_HERMITIAN_MULTIPLY
allocate(b (na_rows,na_cols))
allocate(c (na_rows,na_cols))
#endif
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
allocate(d (na), ds(na))
allocate(sd (na), sds(na))
......@@ -228,10 +240,10 @@ program test
z(:,:) = 0.0
ev(:) = 0.0
#ifdef TEST_EIGENVECTORS
#if defined(TEST_EIGENVECTORS) || defined(TEST_HERMITIAN_MULTIPLY)
#ifdef TEST_MATRIX_ANALYTIC
call prepare_matrix_analytic(na, a, nblk, myid, np_rows, np_cols, my_prow, my_pcol)
as(:,:) = a
as(:,:) = a
#else
if (nev .ge. 1) then
call prepare_matrix(na, myid, sc_desc, a, z, as)
......@@ -248,9 +260,37 @@ program test
d, sd, ds, sds, a, as, nblk, np_rows, &
np_cols, my_prow, my_pcol)
endif
#ifdef TEST_HERMITIAN_MULTIPLY
#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
b(:,:) = 2.0_rk8 * a(:,:)
c(:,:) = 0.0_rk8
#else
b(:,:) = 2.0_rk4 * a(:,:)
c(:,:) = 0.0_rk4
#endif
#endif
#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
b(:,:) = 2.0_ck8 * a(:,:)
c(:,:) = 0.0_ck8
#else
b(:,:) = 2.0_ck4 * a(:,:)
c(:,:) = 0.0_ck4
#endif
#endif
#endif /* TEST_HERMITIAN_MULTIPLY */
#endif /* TEST_MATRIX_ANALYTIC */
#endif /* defined(TEST_EIGENVECTORS) || defined(TEST_HERMITIAN_MULTIPLY) */
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL)
#ifdef TEST_SINGLE
......@@ -268,15 +308,17 @@ program test
#if defined(TEST_CHOLESKY)
#ifdef TEST_SINGLE
diagonalElement = 2.546_c_float
subdiagonalElement = 0.0_c_float
diagonalElement = (2.546_c_float, 0.0_c_float)
subdiagonalElement = (0.0_c_float, 0.0_c_float)
#else
diagonalElement = 2.546_c_double
subdiagonalElement = 0.0_c_double
diagonalElement = (2.546_c_double, 0.0_c_double)
subdiagonalElement = (0.0_c_double, 0.0_c_double)
#endif
call prepare_toeplitz_matrix(na, diagonalElement, subdiagonalElement, &
d, sd, ds, sds, a, as, nblk, np_rows, &
np_cols, my_prow, my_pcol)
#endif /* TEST_CHOLESKY */
e => elpa_allocate()
......@@ -375,6 +417,11 @@ program test
call e%timer_stop("e%cholesky()")
#endif
#if defined(TEST_HERMITIAN_MULTIPLY)
call e%timer_start("e%hermitian_multiply()")
call e%hermitian_multiply('F','F', na, a, b, na_rows, na_cols, c, na_rows, na_cols, error)
call e%timer_stop("e%hermitian_multiply()")
#endif
assert_elpa_ok(error)
......@@ -399,6 +446,9 @@ program test
#ifdef TEST_CHOLESKY
call e%print_times("e%cholesky()")
#endif
#ifdef TEST_HERMITIAN_MULTIPLY
call e%print_times("e%hermitian_multiply()")
#endif
#endif /* TEST_ALL_KERNELS */
endif
......@@ -434,6 +484,12 @@ program test
call check_status(status, myid)
#endif
#if defined(TEST_HERMITIAN_MULTIPLY)
status = check_correctness_hermitian_multiply(na, a, b, c, na_rows, sc_desc, myid )
call check_status(status, myid)
#endif
if (myid == 0) then
print *, ""
endif
......@@ -454,6 +510,11 @@ program test
deallocate(z)
deallocate(ev)
#ifdef TEST_HERMITIAN_MULTIPLY
deallocate(b)
deallocate(c)
#endif
#if defined(TEST_EIGENVALUES) || defined(TEST_SOLVE_TRIDIAGONAL) || defined(TEST_EIGENVECTORS) || defined(TEST_CHOLESKY)
deallocate(d, ds)
deallocate(sd, sds)
......
......@@ -78,6 +78,18 @@ module test_check_correctness
#endif
end interface
! Generic name for the correctness check of the hermitian_multiply test.
! The specific procedure is selected from the type and kind of the matrix
! arguments. The double-precision variants are always part of the interface;
! the single-precision variants are only listed when the corresponding
! WANT_SINGLE_PRECISION_* macro is defined at build time.
interface check_correctness_hermitian_multiply
module procedure check_correctness_hermitian_multiply_complex_double
module procedure check_correctness_hermitian_multiply_real_double
#ifdef WANT_SINGLE_PRECISION_REAL
module procedure check_correctness_hermitian_multiply_real_single
#endif
#ifdef WANT_SINGLE_PRECISION_COMPLEX
module procedure check_correctness_hermitian_multiply_complex_single
#endif
end interface
contains
#define COMPLEXCASE 1
......
......@@ -607,6 +607,7 @@ function check_correctness_&
! compare tmp2 with original matrix
tmp2(:,:) = tmp2(:,:) - as(:,:)
#if REALCASE == 1
#ifdef WITH_MPI
......@@ -689,6 +690,248 @@ function check_correctness_&
end function
! Check the result of the hermitian_multiply test against a reference product.
!
! The reference is formed as tmp2 = a**T * b (real case) or a**H * b (complex
! case) with (P)BLAS, the matrix c is subtracted from it, and the norm of the
! difference is evaluated with the lange routines using norm type "M". With
! MPI the per-process value is additionally reduced with MPI_MAX over
! MPI_COMM_WORLD.
!
! Arguments:
!   na       order of the matrices as handed to the (P)BLAS and lange calls
!   a, b     input matrices of the multiplication (not modified)
!   c        result matrix that is being verified (not modified)
!   na_rows  leading dimension used for the work arrays in the serial calls
!   sc_desc  descriptor handed to the ScaLAPACK and PBLAS calls
!            (presumably the BLACS array descriptor of a, b and c; confirm
!            against the caller)
!   myid     process id; only process 0 prints the error
!
! Result:
!   status   0 if the error is within the tolerance, 1 otherwise
function check_correctness_hermitian_multiply_&
&MATH_DATATYPE&
&_&
&PRECISION&
& (na, a, b, c, na_rows, sc_desc, myid) result(status)
  implicit none
#include "../../src/general/precision_kinds.F90"
  integer(kind=ik)                 :: status
  integer(kind=ik), intent(in)     :: na, myid, na_rows
  integer(kind=ik), intent(in)     :: sc_desc(:)
  ! error norm on this process, its maximum over all processes, and the
  ! acceptance threshold for the active precision
  real(kind=rck)                   :: norm, normmax, tolerance
  integer                          :: mpierr

#if REALCASE == 1
  real(kind=rck), intent(in)       :: a(:,:), b(:,:), c(:,:)
  real(kind=rck), dimension(size(a,dim=1),size(a,dim=2)) :: tmp1, tmp2
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
  real(kind=rck)                   :: pdlange
#else
  real(kind=rck)                   :: pslange
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_REAL
  real(kind=rck)                   :: dlange
#else
  real(kind=rck)                   :: slange
#endif
#endif /* WITH_MPI */
#endif /* REALCASE */

#if COMPLEXCASE == 1
  complex(kind=rck), intent(in)    :: a(:,:), b(:,:), c(:,:)
  complex(kind=rck), dimension(size(a,dim=1),size(a,dim=2)) :: tmp1, tmp2
#ifdef DOUBLE_PRECISION_COMPLEX
  complex(kind=ck8), parameter     :: CZERO = (0.0_rk8,0.0_rk8), CONE = (1.0_rk8,0.0_rk8)
#else
  complex(kind=ck4), parameter     :: CZERO = (0.0_rk4,0.0_rk4), CONE = (1.0_rk4,0.0_rk4)
#endif
  ! The norm routines return a real value also for complex matrices, so the
  ! external functions have to be declared real, not complex.
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
  real(kind=rck)                   :: pzlange
#else
  real(kind=rck)                   :: pclange
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_COMPLEX
  real(kind=rck)                   :: zlange
#else
  real(kind=rck)                   :: clange
#endif
#endif /* WITH_MPI */
#endif /* COMPLEXCASE */

  status = 0

#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
  tmp1(:,:) = 0.0_rk8
#else
  tmp1(:,:) = 0.0_rk4
#endif
#endif /* REALCASE */

#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
  tmp1(:,:) = 0.0_ck8
#else
  tmp1(:,:) = 0.0_ck4
#endif
#endif /* COMPLEXCASE */

#if REALCASE == 1
  ! tmp1 = a**T
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
  call pdtran(na, na, 1.0_rk8, a, 1, 1, sc_desc, 0.0_rk8, tmp1, 1, 1, sc_desc)
#else
  call pstran(na, na, 1.0_rk4, a, 1, 1, sc_desc, 0.0_rk4, tmp1, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
  tmp1 = transpose(a)
#endif /* WITH_MPI */

  ! tmp2 = tmp1 * b
  ! (serial calls: the leading dimension is na_rows, the same value that is
  !  handed to the lange call further down)
#ifdef DOUBLE_PRECISION_REAL
#ifdef WITH_MPI
  call pdgemm("N","N", na, na, na, 1.0_rk8, tmp1, 1, 1, sc_desc, b, 1, 1, &
              sc_desc, 0.0_rk8, tmp2, 1, 1, sc_desc)
#else
  call dgemm("N","N", na, na, na, 1.0_rk8, tmp1, na_rows, b, na_rows, 0.0_rk8, tmp2, na_rows)
#endif
#else /* DOUBLE_PRECISION_REAL */
#ifdef WITH_MPI
  call psgemm("N","N", na, na, na, 1.0_rk4, tmp1, 1, 1, sc_desc, b, 1, 1, &
              sc_desc, 0.0_rk4, tmp2, 1, 1, sc_desc)
#else
  call sgemm("N","N", na, na, na, 1.0_rk4, tmp1, na_rows, b, na_rows, 0.0_rk4, tmp2, na_rows)
#endif
#endif /* DOUBLE_PRECISION_REAL */
#endif /* REALCASE == 1 */

#if COMPLEXCASE == 1
  ! tmp1 = a**H
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
  call pztranc(na, na, CONE, a, 1, 1, sc_desc, CZERO, tmp1, 1, 1, sc_desc)
#else
  call pctranc(na, na, CONE, a, 1, 1, sc_desc, CZERO, tmp1, 1, 1, sc_desc)
#endif
#else /* WITH_MPI */
  tmp1 = transpose(conjg(a))
#endif /* WITH_MPI */

  ! tmp2 = tmp1 * b
#ifdef DOUBLE_PRECISION_COMPLEX
#ifdef WITH_MPI
  call pzgemm("N","N", na, na, na, CONE, tmp1, 1, 1, sc_desc, b, 1, 1, &
              sc_desc, CZERO, tmp2, 1, 1, sc_desc)
#else
  call zgemm("N","N", na, na, na, CONE, tmp1, na_rows, b, na_rows, CZERO, tmp2, na_rows)
#endif
#else /* DOUBLE_PRECISION_COMPLEX */
#ifdef WITH_MPI
  call pcgemm("N","N", na, na, na, CONE, tmp1, 1, 1, sc_desc, b, 1, 1, &
              sc_desc, CZERO, tmp2, 1, 1, sc_desc)
#else
  call cgemm("N","N", na, na, na, CONE, tmp1, na_rows, b, na_rows, CZERO, tmp2, na_rows)
#endif
#endif /* DOUBLE_PRECISION_COMPLEX */
#endif /* COMPLEXCASE == 1 */

  ! compare tmp2 with c
  tmp2(:,:) = tmp2(:,:) - c(:,:)

  ! tmp1 is no longer needed at this point and is handed to the norm
  ! routines as their work array argument.
#if REALCASE == 1
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
  norm = pdlange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#else
  norm = pslange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_REAL
  norm = dlange("M", na, na, tmp2, na_rows, tmp1)
#else
  norm = slange("M", na, na, tmp2, na_rows, tmp1)
#endif
#endif /* WITH_MPI */

#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_REAL
  call mpi_allreduce(norm,normmax,1,MPI_REAL8,MPI_MAX,MPI_COMM_WORLD,mpierr)
#else
  call mpi_allreduce(norm,normmax,1,MPI_REAL4,MPI_MAX,MPI_COMM_WORLD,mpierr)
#endif
#else /* WITH_MPI */
  normmax = norm
#endif /* WITH_MPI */
#endif /* REALCASE == 1 */

#if COMPLEXCASE == 1
#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
  norm = pzlange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#else
  norm = pclange("M",na, na, tmp2, 1, 1, sc_desc, tmp1)
#endif
#else /* WITH_MPI */
#ifdef DOUBLE_PRECISION_COMPLEX
  norm = zlange("M", na, na, tmp2, na_rows, tmp1)
#else
  norm = clange("M", na, na, tmp2, na_rows, tmp1)
#endif
#endif /* WITH_MPI */

#ifdef WITH_MPI
#ifdef DOUBLE_PRECISION_COMPLEX
  call mpi_allreduce(norm,normmax,1,MPI_REAL8,MPI_MAX,MPI_COMM_WORLD,mpierr)
#else
  call mpi_allreduce(norm,normmax,1,MPI_REAL4,MPI_MAX,MPI_COMM_WORLD,mpierr)
#endif
#else /* WITH_MPI */
  normmax = norm
#endif /* WITH_MPI */
#endif /* COMPLEXCASE == 1 */

  if (myid .eq. 0) then
    print *," Maximum error of result: ", normmax
  endif

  ! Select the threshold with the precision macro that belongs to the active
  ! data type, in the same way as every other precision switch above, so that
  ! a complex double build is not judged by the single precision threshold.
#if REALCASE == 1
#ifdef DOUBLE_PRECISION_REAL
  tolerance = 5e-12_rk8
#else
  tolerance = 5e-4_rk4
#endif
#endif /* REALCASE == 1 */

#if COMPLEXCASE == 1
#ifdef DOUBLE_PRECISION_COMPLEX
  tolerance = 5e-12_rk8
#else
  tolerance = 5e-4_rk4
#endif
#endif /* COMPLEXCASE == 1 */

  if (normmax .gt. tolerance) then
    status = 1
  endif
end function
! vim: syntax=fortran
......@@ -59,11 +59,13 @@ module test_prepare_matrix
! Generic name for the set-up of a Toeplitz test matrix. The specific
! procedure is selected from the type and kind of the actual arguments.
! The "mixed_complex" variants take real diagonal and subdiagonal values
! together with a complex matrix. Single-precision variants are only listed
! when the corresponding WANT_SINGLE_PRECISION_* macro is defined.
interface prepare_toeplitz_matrix
module procedure prepare_toeplitz_matrix_complex_double
module procedure prepare_toeplitz_matrix_real_double
module procedure prepare_toeplitz_matrix_mixed_complex_complex_double
#ifdef WANT_SINGLE_PRECISION_REAL
module procedure prepare_toeplitz_matrix_real_single
#endif
#ifdef WANT_SINGLE_PRECISION_COMPLEX
module procedure prepare_toeplitz_matrix_complex_single
module procedure prepare_toeplitz_matrix_mixed_complex_complex_single
#endif
end interface
......
......@@ -216,9 +216,19 @@ subroutine prepare_matrix_&
implicit none
integer, intent(in) :: na, nblk, np_rows, np_cols, my_prow, my_pcol
#if REALCASE == 1
real(kind=C_DATATYPE_KIND) :: diagonalElement, subdiagonalElement
real(kind=C_DATATYPE_KIND) :: d(:), sd(:), ds(:), sds(:)
#endif
#if COMPLEXCASE == 1
complex(kind=C_DATATYPE_KIND) :: diagonalElement, subdiagonalElement
complex(kind=C_DATATYPE_KIND) :: d(:), sd(:), ds(:), sds(:)
#endif
#if REALCASE == 1
real(kind=C_DATATYPE_KIND) :: a(:,:), as(:,:)
#endif
......@@ -254,5 +264,62 @@ subroutine prepare_matrix_&
as = a
end subroutine
! Set up a Toeplitz test matrix from real diagonal and subdiagonal values
! while the matrix itself is complex.
!
! Complex case: d and sd are filled with the diagonal and subdiagonal value,
! the locally stored entries of the main diagonal and of the first upper and
! lower off-diagonal of a are set, and copies are stored in ds, sds and as.
! All other entries of a are left as they were on entry.
!
! Real case: the routine has no matrix arguments and its body is empty.
!
! Arguments:
!   na                      global matrix order (upper bound of the loops)
!   diagonalElement         value written to the main diagonal and to d
!   subdiagonalElement      value written to both off-diagonals and to sd
!   d, sd                   diagonal and subdiagonal value vectors
!   ds, sds                 copies of d and sd
!   a, as                   local part of the matrix and its copy
!                           (complex case only)
!   nblk, np_rows, np_cols, my_prow, my_pcol
!                           forwarded unchanged to
!                           map_global_array_index_to_local_index
subroutine prepare_toeplitz_matrix_mixed_complex&
&_&
&MATH_DATATYPE&
&_&
&PRECISION&
#if COMPLEXCASE == 1
& (na, diagonalElement, subdiagonalElement, d, sd, ds, sds, a, as, &
nblk, np_rows, np_cols, my_prow, my_pcol)
#endif
#if REALCASE == 1
& (na, diagonalElement, subdiagonalElement, d, sd, ds, sds, &
nblk, np_rows, np_cols, my_prow, my_pcol)
#endif
  use test_util
  implicit none
  integer, intent(in)                          :: na, nblk, np_rows, np_cols, my_prow, my_pcol
  real(kind=C_DATATYPE_KIND), intent(in)       :: diagonalElement, subdiagonalElement
  ! intent(inout) rather than intent(out): in the real case nothing is
  ! assigned, so the values of the caller must stay defined.
  real(kind=C_DATATYPE_KIND), intent(inout)    :: d(:), sd(:), ds(:), sds(:)
#if COMPLEXCASE == 1
  ! a is only partially overwritten (three diagonals), hence intent(inout)
  complex(kind=C_DATATYPE_KIND), intent(inout) :: a(:,:), as(:,:)
#endif
  integer                                      :: ii, rowLocal, colLocal

#if COMPLEXCASE == 1
  d(:) = diagonalElement
  sd(:) = subdiagonalElement

  ! set up the diagonal and subdiagonals (for general solver test)
  ! The mapping function returns true only if the global element is stored
  ! on this process and then provides its local row and column index.

  ! main diagonal
  do ii=1, na
    if (map_global_array_index_to_local_index(ii, ii, rowLocal, colLocal, nblk, np_rows, np_cols, my_prow, my_pcol)) then
      a(rowLocal,colLocal) = diagonalElement
    endif
  enddo

  ! first upper off-diagonal
  do ii=1, na-1
    if (map_global_array_index_to_local_index(ii, ii+1, rowLocal, colLocal, nblk, np_rows, np_cols, my_prow, my_pcol)) then
      a(rowLocal,colLocal) = subdiagonalElement
    endif
  enddo

  ! first lower off-diagonal
  do ii=2, na
    if (map_global_array_index_to_local_index(ii, ii-1, rowLocal, colLocal, nblk, np_rows, np_cols, my_prow, my_pcol)) then
      a(rowLocal,colLocal) = subdiagonalElement
    endif
  enddo

  ! keep copies of the generated data
  ds = d
  sds = sd
  as = a
#endif
end subroutine
! vim: syntax=fortran
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment