Commit 95cb4b69 authored by Pavel Kus's avatar Pavel Kus

using cannons algorithm for the back transformation as well

parent 47ca86bb
...@@ -452,17 +452,17 @@ void d_cannons_triang_rectangular(double* U, double* B, int np_rows, int np_cols ...@@ -452,17 +452,17 @@ void d_cannons_triang_rectangular(double* U, double* B, int np_rows, int np_cols
!f> end subroutine !f> end subroutine
!f> end interface !f> end interface
*/ */
void d_cannons_triang_rectantular_c(double* U, double* B, int local_rows, int local_cols, int np_rows, int np_cols, void d_cannons_triang_rectangular_c(double* U, double* B, int local_rows, int local_cols, int np_rows, int np_cols,
int my_prow, int my_pcol, int* u_desc, int* b_desc, double *Res, int row_comm, int col_comm) int my_prow, int my_pcol, int* u_desc, int* b_desc, double *Res, int row_comm, int col_comm)
{ {
#ifdef WITH_MPI #ifdef WITH_MPI
MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm); MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm);
MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm); MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm);
//int c_my_prow, c_my_pcol; // int c_my_prow, c_my_pcol;
//MPI_Comm_rank(c_row_comm, &c_my_prow); // MPI_Comm_rank(c_row_comm, &c_my_prow);
//MPI_Comm_rank(c_col_comm, &c_my_pcol); // MPI_Comm_rank(c_col_comm, &c_my_pcol);
//printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol); // printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// BEWARE // BEWARE
// in the cannons algorithm, column and row communicators are exchanged // in the cannons algorithm, column and row communicators are exchanged
......
...@@ -922,10 +922,10 @@ void d_cannons_reduction_c(double* A, double* U, int local_rows, int local_cols, ...@@ -922,10 +922,10 @@ void d_cannons_reduction_c(double* A, double* U, int local_rows, int local_cols,
MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm); MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm);
MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm); MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm);
//int c_my_prow, c_my_pcol; // int c_my_prow, c_my_pcol;
//MPI_Comm_rank(c_row_comm, &c_my_prow); // MPI_Comm_rank(c_row_comm, &c_my_prow);
//MPI_Comm_rank(c_col_comm, &c_my_pcol); // MPI_Comm_rank(c_col_comm, &c_my_pcol);
//printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol); // printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// BEWARE // BEWARE
// in the cannons algorithm, column and row communicators are exchanged // in the cannons algorithm, column and row communicators are exchanged
......
...@@ -348,8 +348,9 @@ module elpa_impl ...@@ -348,8 +348,9 @@ module elpa_impl
error = self%setup() error = self%setup()
end function end function
function elpa_construct_scalapack_descriptor(self, sc_desc) result(error) function elpa_construct_scalapack_descriptor(self, sc_desc, rectangular_for_ev) result(error)
class(elpa_impl_t), intent(inout) :: self class(elpa_impl_t), intent(inout) :: self
logical, intent(in) :: rectangular_for_ev
integer :: error, blacs_ctx integer :: error, blacs_ctx
integer, intent(out) :: sc_desc(SC_DESC_LEN) integer, intent(out) :: sc_desc(SC_DESC_LEN)
...@@ -363,7 +364,11 @@ module elpa_impl ...@@ -363,7 +364,11 @@ module elpa_impl
sc_desc(1) = 1 sc_desc(1) = 1
sc_desc(2) = blacs_ctx sc_desc(2) = blacs_ctx
sc_desc(3) = self%na sc_desc(3) = self%na
sc_desc(4) = self%na if(rectangular_for_ev) then
sc_desc(4) = self%nev
else
sc_desc(4) = self%na
endif
sc_desc(5) = self%nblk sc_desc(5) = self%nblk
sc_desc(6) = self%nblk sc_desc(6) = self%nblk
sc_desc(7) = 0 sc_desc(7) = 0
...@@ -736,6 +741,12 @@ module elpa_impl ...@@ -736,6 +741,12 @@ module elpa_impl
#undef SINGLE_PRECISION #undef SINGLE_PRECISION
#endif #endif
! function use_cannons_algorithm(self) result(use_cannon, do_print)
! class(elpa_impl_t), intent(inout), target :: self
! logical :: use_cannon
! logical, intent(in) :: do_print
! end function
!
#ifdef ENABLE_AUTOTUNING #ifdef ENABLE_AUTOTUNING
!> \brief function to setup the ELPA autotuning and create the autotune object !> \brief function to setup the ELPA autotuning and create the autotune object
!> Parameters !> Parameters
......
...@@ -58,7 +58,7 @@ ...@@ -58,7 +58,7 @@
use_cannon = 0 use_cannon = 0
endif endif
error = self%construct_scalapack_descriptor(sc_desc) error = self%construct_scalapack_descriptor(sc_desc, .false.)
if(error .NE. ELPA_OK) return if(error .NE. ELPA_OK) return
if (.not. is_already_decomposed) then if (.not. is_already_decomposed) then
...@@ -136,28 +136,64 @@ ...@@ -136,28 +136,64 @@
#else #else
MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols) MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
#endif #endif
integer(kind=ik) :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
integer :: error integer :: error
integer :: sc_desc(SC_DESC_LEN) integer :: sc_desc(SC_DESC_LEN)
integer :: sc_desc_ev(SC_DESC_LEN)
integer(kind=ik) :: use_cannon
MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)
call self%get("mpi_comm_rows",mpi_comm_rows,error)
call self%get("mpi_comm_cols",mpi_comm_cols,error)
call self%get("mpi_comm_parent", mpi_comm_all,error)
call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call self%timer_start("transform_back_generalized()") call self%timer_start("transform_back_generalized()")
call self%get("cannon_for_generalized",use_cannon,error)
#if !defined(REALCASE) || !defined(DOUBLE_PRECISION)
use_cannon = 0
#endif
error = self%construct_scalapack_descriptor(sc_desc) #if !defined(WITH_MPI)
use_cannon = 0
#endif
if (mod(np_cols, np_rows) /= 0) then
use_cannon = 0
endif
error = self%construct_scalapack_descriptor(sc_desc, .false.)
error = self%construct_scalapack_descriptor(sc_desc_ev, .true.)
if(error .NE. ELPA_OK) return if(error .NE. ELPA_OK) return
call self%timer_start("scalapack multiply inv(U) * Q") if(use_cannon == 1) then
#if defined(REALCASE) && defined(DOUBLE_PRECISION)
call cannons_triang_rectangular(b, q, self%local_nrows, self%local_ncols, np_rows, np_cols, my_prow, my_pcol, &
sc_desc, sc_desc_ev, tmp, mpi_comm_rows, mpi_comm_cols);
q(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
#endif
else
call self%timer_start("scalapack multiply inv(U) * Q")
#ifdef WITH_MPI #ifdef WITH_MPI
! Q <- inv(U) * Q ! Q <- inv(U) * Q
call p& call p&
&BLAS_CHAR& &BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, & &trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, 1, 1, sc_desc, q, 1, 1, sc_desc) ONE, b, 1, 1, sc_desc, q, 1, 1, sc_desc)
#else #else
call BLAS_CHAR& call BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, & &trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, self%na, q, self%na) ONE, b, self%na, q, self%na)
#endif #endif
call self%timer_stop("scalapack multiply inv(U) * Q") call self%timer_stop("scalapack multiply inv(U) * Q")
endif
call self%timer_stop("transform_back_generalized()") call self%timer_stop("transform_back_generalized()")
end subroutine end subroutine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment