! elpa_impl_generalized_transform_template.F90
! using elpa internal Hermitian multiply is faster than scalapack multiply, but we need an extra
! temporary matrix.
! using the Cannon algorithm should be the fastest. After this is verified, the other options should be removed
! however, we need the extra temporary matrix as well.

   !> Transform the generalized eigenproblem A*x = lambda*B*x to the standard
   !> form A'*x' = lambda*x', via the Cholesky factorization B = U^H*U and
   !> A' = inv(U)^H * A * inv(U).
   !>
   !> \param self                   the elpa_impl_t handle
   !> \param a                      on entry matrix A; on exit the transformed matrix A'
   !> \param b                      on entry matrix B (unless is_already_decomposed);
   !>                               on exit inv(U), needed later by the back transformation
   !> \param is_already_decomposed  if .true., b already holds inv(U) from a previous call,
   !>                               so the Cholesky/inversion step is skipped
   !> \param error                  ELPA error code (ELPA_OK on success)
   subroutine elpa_transform_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, a, b, is_already_decomposed, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, *), b(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, self%local_ncols), b(self%local_nrows, self%local_ncols)
#endif
     integer                :: error
     logical                :: is_already_decomposed
     integer                :: sc_desc(SC_DESC_LEN)
     integer(kind=ik)       :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
     integer(kind=ik)       :: BuffLevelInt, use_cannon

     ! extra temporary matrix, needed both by cannons_reduction and by the
     ! hermitian-multiply fallback path
     MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)

     call self%get("mpi_comm_rows",mpi_comm_rows,error)
     call self%get("mpi_comm_cols",mpi_comm_cols,error)
     call self%get("mpi_comm_parent", mpi_comm_all,error)

     call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
     call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
     call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
     call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
     call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)

     call self%timer_start("transform_generalized()")
     call self%get("cannon_for_generalized",use_cannon,error)

#if !defined(WITH_MPI)
     if(my_p == 0) then
       ! BUGFIX: message previously read "can be used with MPI", the opposite
       ! of what this branch means
       write(*,*) "Cannons algorithm can only be used with MPI"
       write(*,*) "Switching to elpa Hermitian and scalapack"
     end if
     use_cannon = 0
#endif

     ! cannons_reduction requires np_cols to be an integer multiple of np_rows
     if (mod(np_cols, np_rows) /= 0) then
       ! BUGFIX: warn only when the Cannon algorithm was actually requested;
       ! previously the "Switching ..." message appeared even with use_cannon == 0
       if(use_cannon == 1 .and. my_p == 0) then
         write(*,*) "To use Cannons algorithm, np_cols must be a multiple of np_rows."
         write(*,*) "Switching to elpa Hermitian and scalapack"
       end if
       use_cannon = 0
     endif

     error = self%construct_scalapack_descriptor(sc_desc, .false.)
     if(error .NE. ELPA_OK) return

     if (.not. is_already_decomposed) then
       ! B = U^H*U, B <- U (Cholesky factorization)
       call self%elpa_cholesky_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if(error .NE. ELPA_OK) return
       ! B <- inv(U)
       call self%elpa_invert_trm_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if(error .NE. ELPA_OK) return
     end if

     if(use_cannon == 1) then
       !TODO set the value properly
       !TODO tunable parameter?
       BuffLevelInt = 1

       call self%timer_start("cannons_reduction")
       ! BEWARE! even though tmp is output from the routine, it has to be zero on input!
       tmp = 0.0_rck
#ifdef WITH_MPI
       call cannons_reduction_&
         &ELPA_IMPL_SUFFIX&
         &(a, b, self%local_nrows, self%local_ncols, sc_desc, tmp, BuffLevelInt, mpi_comm_rows, mpi_comm_cols)
#endif
       call self%timer_stop("cannons_reduction")

       a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)

     else  ! do not use cannon algorithm, use elpa hermitian multiply and scalapack instead
       ! tmp <- inv(U^T) * A (we have to use temporary variable)
       call self%elpa_hermitian_multiply_&
           &ELPA_IMPL_SUFFIX&
           &('U','F', self%na, b, a, self%local_nrows, self%local_ncols, tmp, &
                                 self%local_nrows, self%local_ncols, error)
       if(error .NE. ELPA_OK) return

       ! A <- inv(U)^T * A
       a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)

       ! A <- inv(U)^T * A * inv(U)
       ! For this multiplication we do not have internal function in ELPA,
       ! so we have to call scalapack
       call self%timer_start("scalapack multiply A * inv(U)")
#ifdef WITH_MPI
       call p&
           &BLAS_CHAR&
           &trmm("R", "U", "N", "N", self%na, self%na, &
                 ONE, b, 1, 1, sc_desc, a, 1, 1, sc_desc)
#else
       call BLAS_CHAR&
           &trmm("R", "U", "N", "N", self%na, self%na, &
                 ONE, b, self%na, a, self%na)
#endif
       call self%timer_stop("scalapack multiply A * inv(U)")
     endif ! use_cannon

     call self%timer_stop("transform_generalized()")
    end subroutine


    !> Back-transform the eigenvectors of the standard problem to eigenvectors
    !> of the generalized problem: Q <- inv(U) * Q, where b holds inv(U) as
    !> produced by elpa_transform_generalized.
    !>
    !> \param self   the elpa_impl_t handle
    !> \param b      inv(U) from the forward transformation (read-only here)
    !> \param q      on entry eigenvectors of the standard problem;
    !>               on exit eigenvectors of the generalized problem
    !> \param error  ELPA error code (ELPA_OK on success)
    subroutine elpa_transform_back_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, b, q, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, *), q(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
#endif
     integer(kind=ik)       :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
     integer                :: error
     integer                :: sc_desc(SC_DESC_LEN)
     integer                :: sc_desc_ev(SC_DESC_LEN)
     integer(kind=ik)       :: use_cannon

     ! temporary receive buffer for cannons_triang_rectangular
     MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)

     call self%get("mpi_comm_rows",mpi_comm_rows,error)
     call self%get("mpi_comm_cols",mpi_comm_cols,error)
     call self%get("mpi_comm_parent", mpi_comm_all,error)

     call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
     call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
     call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
     call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
     call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)

     call self%timer_start("transform_back_generalized()")
     call self%get("cannon_for_generalized",use_cannon,error)

#if !defined(WITH_MPI)
     use_cannon = 0
#endif

     ! cannons_triang_rectangular requires np_cols to be a multiple of np_rows
     ! (warning was already issued by the forward transformation)
     if (mod(np_cols, np_rows) /= 0) then
       use_cannon = 0
     endif

     ! BUGFIX: check the error of the first descriptor before the second call
     ! overwrites it; previously a failure of the first call could go unnoticed
     error = self%construct_scalapack_descriptor(sc_desc, .false.)
     if(error .NE. ELPA_OK) return
     error = self%construct_scalapack_descriptor(sc_desc_ev, .true.)
     if(error .NE. ELPA_OK) return

     if(use_cannon == 1) then
       call self%timer_start("cannons_triang_rectangular")
#ifdef WITH_MPI
       call cannons_triang_rectangular_&
         &ELPA_IMPL_SUFFIX&
         &(b, q, self%local_nrows, self%local_ncols, sc_desc, sc_desc_ev, tmp, mpi_comm_rows, mpi_comm_cols)
#endif
       call self%timer_stop("cannons_triang_rectangular")

       q(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
     else
       call self%timer_start("scalapack multiply inv(U) * Q")
#ifdef WITH_MPI
       ! Q <- inv(U) * Q
       call p&
           &BLAS_CHAR&
           &trmm("L", "U", "N", "N", self%na, self%nev, &
                 ONE, b, 1, 1, sc_desc,  q, 1, 1, sc_desc)
#else
       call BLAS_CHAR&
           &trmm("L", "U", "N", "N", self%na, self%nev, &
                 ONE, b, self%na, q, self%na)
#endif
       call self%timer_stop("scalapack multiply inv(U) * Q")
     endif
     call self%timer_stop("transform_back_generalized()")

    end subroutine