Commit 95cb4b69 authored by Pavel Kus's avatar Pavel Kus

using cannons algorithm for the back transformation as well

parent 47ca86bb
......@@ -452,17 +452,17 @@ void d_cannons_triang_rectangular(double* U, double* B, int np_rows, int np_cols
!f> end subroutine
!f> end interface
*/
void d_cannons_triang_rectantular_c(double* U, double* B, int local_rows, int local_cols, int np_rows, int np_cols,
void d_cannons_triang_rectangular_c(double* U, double* B, int local_rows, int local_cols, int np_rows, int np_cols,
int my_prow, int my_pcol, int* u_desc, int* b_desc, double *Res, int row_comm, int col_comm)
{
#ifdef WITH_MPI
MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm);
MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm);
//int c_my_prow, c_my_pcol;
//MPI_Comm_rank(c_row_comm, &c_my_prow);
//MPI_Comm_rank(c_col_comm, &c_my_pcol);
//printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// int c_my_prow, c_my_pcol;
// MPI_Comm_rank(c_row_comm, &c_my_prow);
// MPI_Comm_rank(c_col_comm, &c_my_pcol);
// printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// BEWARE
// in the cannons algorithm, column and row communicators are exchanged
......
......@@ -922,10 +922,10 @@ void d_cannons_reduction_c(double* A, double* U, int local_rows, int local_cols,
MPI_Comm c_row_comm = MPI_Comm_f2c(row_comm);
MPI_Comm c_col_comm = MPI_Comm_f2c(col_comm);
//int c_my_prow, c_my_pcol;
//MPI_Comm_rank(c_row_comm, &c_my_prow);
//MPI_Comm_rank(c_col_comm, &c_my_pcol);
//printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// int c_my_prow, c_my_pcol;
// MPI_Comm_rank(c_row_comm, &c_my_prow);
// MPI_Comm_rank(c_col_comm, &c_my_pcol);
// printf("FORT<->C row: %d<->%d, col: %d<->%d\n", my_prow, c_my_prow, my_pcol, c_my_pcol);
// BEWARE
// in the cannons algorithm, column and row communicators are exchanged
......
......@@ -348,8 +348,9 @@ module elpa_impl
error = self%setup()
end function
function elpa_construct_scalapack_descriptor(self, sc_desc) result(error)
function elpa_construct_scalapack_descriptor(self, sc_desc, rectangular_for_ev) result(error)
class(elpa_impl_t), intent(inout) :: self
logical, intent(in) :: rectangular_for_ev
integer :: error, blacs_ctx
integer, intent(out) :: sc_desc(SC_DESC_LEN)
......@@ -363,7 +364,11 @@ module elpa_impl
sc_desc(1) = 1
sc_desc(2) = blacs_ctx
sc_desc(3) = self%na
sc_desc(4) = self%na
if(rectangular_for_ev) then
sc_desc(4) = self%nev
else
sc_desc(4) = self%na
endif
sc_desc(5) = self%nblk
sc_desc(6) = self%nblk
sc_desc(7) = 0
......@@ -736,6 +741,12 @@ module elpa_impl
#undef SINGLE_PRECISION
#endif
! function use_cannons_algorithm(self) result(use_cannon, do_print)
! class(elpa_impl_t), intent(inout), target :: self
! logical :: use_cannon
! logical, intent(in) :: do_print
! end function
!
#ifdef ENABLE_AUTOTUNING
!> \brief function to setup the ELPA autotuning and create the autotune object
!> Parameters
......
......@@ -58,7 +58,7 @@
use_cannon = 0
endif
error = self%construct_scalapack_descriptor(sc_desc)
error = self%construct_scalapack_descriptor(sc_desc, .false.)
if(error .NE. ELPA_OK) return
if (.not. is_already_decomposed) then
......@@ -136,28 +136,64 @@
#else
MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
#endif
integer(kind=ik) :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
integer :: error
integer :: sc_desc(SC_DESC_LEN)
integer :: sc_desc_ev(SC_DESC_LEN)
integer(kind=ik) :: use_cannon
MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)
call self%get("mpi_comm_rows",mpi_comm_rows,error)
call self%get("mpi_comm_cols",mpi_comm_cols,error)
call self%get("mpi_comm_parent", mpi_comm_all,error)
call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
call self%timer_start("transform_back_generalized()")
call self%get("cannon_for_generalized",use_cannon,error)
#if !defined(REALCASE) || !defined(DOUBLE_PRECISION)
use_cannon = 0
#endif
error = self%construct_scalapack_descriptor(sc_desc)
#if !defined(WITH_MPI)
use_cannon = 0
#endif
if (mod(np_cols, np_rows) /= 0) then
use_cannon = 0
endif
error = self%construct_scalapack_descriptor(sc_desc, .false.)
error = self%construct_scalapack_descriptor(sc_desc_ev, .true.)
if(error .NE. ELPA_OK) return
call self%timer_start("scalapack multiply inv(U) * Q")
if(use_cannon == 1) then
#if defined(REALCASE) && defined(DOUBLE_PRECISION)
call cannons_triang_rectangular(b, q, self%local_nrows, self%local_ncols, np_rows, np_cols, my_prow, my_pcol, &
sc_desc, sc_desc_ev, tmp, mpi_comm_rows, mpi_comm_cols);
q(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
#endif
else
call self%timer_start("scalapack multiply inv(U) * Q")
#ifdef WITH_MPI
! Q <- inv(U) * Q
call p&
&BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, 1, 1, sc_desc, q, 1, 1, sc_desc)
! Q <- inv(U) * Q
call p&
&BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, 1, 1, sc_desc, q, 1, 1, sc_desc)
#else
call BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, self%na, q, self%na)
call BLAS_CHAR&
&trmm("L", "U", "N", "N", self%na, self%nev, &
ONE, b, self%na, q, self%na)
#endif
call self%timer_stop("scalapack multiply inv(U) * Q")
call self%timer_stop("scalapack multiply inv(U) * Q")
endif
call self%timer_stop("transform_back_generalized()")
end subroutine
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment