elpa_impl_generalized_transform_template.F90 7.04 KB
Newer Older
1
2
3
4
5
! using the elpa internal Hermitian multiply is faster than the scalapack multiply, but we need an extra
! temporary matrix.
! using the cannon algorithm should be the fastest. After this is verified, the other options should be removed
! however, we need the extra temporary matrix as well.

Andreas Marek's avatar
Andreas Marek committed
6
   !> Transform the generalized eigenproblem A*x = lambda*B*x into standard form.
   !>
   !> Unless is_already_decomposed is set, B is Cholesky-factorized (B = U^H * U)
   !> and U is inverted in place, so that on exit b holds inv(U).
   !> A is then overwritten with inv(U)^H * A * inv(U), either via Cannon's
   !> algorithm (option "cannon_for_generalized") or via ELPA Hermitian multiply
   !> followed by a ScaLAPACK triangular multiply.
   !>
   !> \param self                    the elpa_impl_t object
   !> \param a                       on entry matrix A, on exit inv(U)^H * A * inv(U)
   !> \param b                       on entry matrix B (or inv(U) if already decomposed), on exit inv(U)
   !> \param is_already_decomposed   skip Cholesky + inversion of B if .true.
   !> \param error                   ELPA_OK on success
   subroutine elpa_transform_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, a, b, is_already_decomposed, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, *), b(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, self%local_ncols), b(self%local_nrows, self%local_ncols)
#endif
     integer                :: error
     logical                :: is_already_decomposed
     integer                :: sc_desc(SC_DESC_LEN)
     integer(kind=ik)       :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
     integer(kind=ik)       :: BuffLevelInt, use_cannon

     MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)

     call self%get("mpi_comm_rows",mpi_comm_rows,error)
     call self%get("mpi_comm_cols",mpi_comm_cols,error)
     call self%get("mpi_comm_parent", mpi_comm_all,error)

     call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
     call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
     call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
     call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
     call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)

     call self%timer_start("transform_generalized()")

     call self%get("cannon_for_generalized",use_cannon,error)

#if !defined(WITH_MPI)
     ! Cannon's algorithm requires MPI; warn only if it was actually requested,
     ! otherwise the "Switching to ..." message would be spurious.
     if(use_cannon == 1) then
       if(my_p == 0) then
         write(*,*) "Cannons algorithm can only be used with MPI"
         write(*,*) "Switching to elpa Hermitian and scalapack"
       end if
       use_cannon = 0
     end if
#endif

     ! Cannon's algorithm requires np_cols to be a multiple of np_rows.
     if (mod(np_cols, np_rows) /= 0) then
       if(use_cannon == 1 .and. my_p == 0) then
         write(*,*) "To use Cannons algorithm, np_cols must be a multiple of np_rows."
         write(*,*) "Switching to elpa Hermitian and scalapack"
       end if
       use_cannon = 0
     endif

     error = self%construct_scalapack_descriptor(sc_desc, .false.)
     if(error .NE. ELPA_OK) return

     if (.not. is_already_decomposed) then
       ! B = U^T*U, B<-U
       call self%elpa_cholesky_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if(error .NE. ELPA_OK) return
       ! B <- inv(U)
       call self%elpa_invert_trm_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if(error .NE. ELPA_OK) return
     end if

     if(use_cannon == 1) then
       !TODO set the value properly
       !TODO tunable parameter?
       BuffLevelInt = 1

       call self%timer_start("cannons_reduction")
       ! BEWARE! even though tmp is output from the routine, it has to be zero on input!
       tmp = 0.0_rck
#ifdef WITH_MPI
       call cannons_reduction_&
         &ELPA_IMPL_SUFFIX&
         &(a, b, self%local_nrows, self%local_ncols, sc_desc, tmp, BuffLevelInt, mpi_comm_rows, mpi_comm_cols)
#endif
       call self%timer_stop("cannons_reduction")

       a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)

     else  ! do not use cannon algorithm, use elpa hermitian multiply and scalapack instead
       ! tmp <- inv(U^T) * A (we have to use temporary variable)
       call self%elpa_hermitian_multiply_&
           &ELPA_IMPL_SUFFIX&
           &('U','F', self%na, b, a, self%local_nrows, self%local_ncols, tmp, &
                                 self%local_nrows, self%local_ncols, error)
       if(error .NE. ELPA_OK) return

       ! A <- inv(U)^T * A
       a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)

       ! A <- inv(U)^T * A * inv(U)
       ! For this multiplication we do not have internal function in ELPA,
       ! so we have to call scalapack
       call self%timer_start("scalapack multiply A * inv(U)")
#ifdef WITH_MPI
       call p&
           &BLAS_CHAR&
           &trmm("R", "U", "N", "N", self%na, self%na, &
                 ONE, b, 1, 1, sc_desc, a, 1, 1, sc_desc)
#else
       call BLAS_CHAR&
           &trmm("R", "U", "N", "N", self%na, self%na, &
                 ONE, b, self%na, a, self%na)
#endif
       call self%timer_stop("scalapack multiply A * inv(U)")
     endif ! use_cannon

     !write(*, *) my_prow, my_pcol, "A(2,3)", a(2,3)

     call self%timer_stop("transform_generalized()")
    end subroutine

120
121
122

    !> Transform eigenvectors of the standard problem back to the generalized problem:
    !> Q <- inv(U) * Q, where b must already contain inv(U) as produced by
    !> elpa_transform_generalized. Uses Cannon's triangular-rectangular multiply
    !> when enabled (option "cannon_for_generalized"), ScaLAPACK/BLAS otherwise.
    !>
    !> \param self   the elpa_impl_t object
    !> \param b      inv(U), the inverted Cholesky factor of B
    !> \param q      on entry eigenvectors of the standard problem, on exit inv(U) * Q
    !> \param error  ELPA_OK on success
    subroutine elpa_transform_back_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, b, q, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, *), q(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
#endif
     integer(kind=ik)       :: my_p, my_prow, my_pcol, np_rows, np_cols, mpierr, mpi_comm_rows, mpi_comm_cols, mpi_comm_all
     integer                :: error
     integer                :: sc_desc(SC_DESC_LEN)
     integer                :: sc_desc_ev(SC_DESC_LEN)
     integer(kind=ik)       :: use_cannon

     MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)

     call self%get("mpi_comm_rows",mpi_comm_rows,error)
     call self%get("mpi_comm_cols",mpi_comm_cols,error)
     call self%get("mpi_comm_parent", mpi_comm_all,error)

     call mpi_comm_rank(mpi_comm_all,my_p,mpierr)
     call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
     call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
     call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
     call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)

     call self%timer_start("transform_back_generalized()")
     call self%get("cannon_for_generalized",use_cannon,error)

#if !defined(WITH_MPI)
     use_cannon = 0
#endif

     ! Cannon's algorithm requires np_cols to be a multiple of np_rows.
     if (mod(np_cols, np_rows) /= 0) then
       use_cannon = 0
     endif

     ! Check each descriptor construction separately: a failure of the first
     ! call must not be masked by a success of the second.
     error = self%construct_scalapack_descriptor(sc_desc, .false.)
     if(error .NE. ELPA_OK) return
     error = self%construct_scalapack_descriptor(sc_desc_ev, .true.)
     if(error .NE. ELPA_OK) return

     if(use_cannon == 1) then
       call self%timer_start("cannons_triang_rectangular")
#ifdef WITH_MPI
       call cannons_triang_rectangular_&
         &ELPA_IMPL_SUFFIX&
         &(b, q, self%local_nrows, self%local_ncols, sc_desc, sc_desc_ev, tmp, mpi_comm_rows, mpi_comm_cols);
#endif
       call self%timer_stop("cannons_triang_rectangular")

       q(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
     else
       call self%timer_start("scalapack multiply inv(U) * Q")
#ifdef WITH_MPI
       ! Q <- inv(U) * Q
       call p&
           &BLAS_CHAR&
           &trmm("L", "U", "N", "N", self%na, self%nev, &
                 ONE, b, 1, 1, sc_desc,  q, 1, 1, sc_desc)
#else
       call BLAS_CHAR&
           &trmm("L", "U", "N", "N", self%na, self%nev, &
                 ONE, b, self%na, q, self%na)
#endif
       call self%timer_stop("scalapack multiply inv(U) * Q")
     endif
     call self%timer_stop("transform_back_generalized()")

    end subroutine