elpa_impl_generalized_transform_template.F90 5.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
! Using the ELPA-internal Hermitian multiply is faster than the ScaLAPACK multiply, but it needs an
! extra temporary matrix.
! Using the Cannon algorithm should be the fastest; once this is verified, the other options should
! be removed. However, it needs the extra temporary matrix as well.
! Compile-time selection of the forward transformation A <- inv(U)^T * A * inv(U).
! All three forward switches are cleared first so that only the choice made below is active.
#undef FORWARD_ELPA_CANNON
#undef  FORWARD_SCALAPACK
#undef FORWARD_ELPA_HERMITIAN

! The Cannon-based variant is enabled for the real double-precision case only; every other
! combination uses the Hermitian-multiply variant.
#if defined(REALCASE) && defined(DOUBLE_PRECISION)
#define  FORWARD_ELPA_CANNON
!#define  FORWARD_ELPA_HERMITIAN
#else
!TODO first just for real double...
#define FORWARD_ELPA_HERMITIAN
#endif

! Switches for the backward transformation.
! NOTE(review): neither of these is tested in the part of the file visible here - confirm they are still needed.
#define BACKWARD_ELPA_CANNON
#undef  BACKWARD_SCALAPACK

Andreas Marek's avatar
Andreas Marek committed
20
   ! Forward step of the generalized eigenproblem: overwrite a with inv(U)^T * A * inv(U),
   ! where B = U^T * U is the Cholesky factorization of b.
   !
   ! Arguments:
   !   self                   ELPA object; supplies na, local_nrows, local_ncols, the timer and
   !                          (Cannon variant) the row and column communicators.
   !   a                      local block of A; overwritten with the transformed matrix.
   !   b                      local block of B. Unless is_already_decomposed is true it is
   !                          overwritten, first by the Cholesky step and then by the triangular
   !                          inversion, so that it holds inv(U) afterwards.
   !   is_already_decomposed  if true, both steps on b are skipped and b is used as passed in
   !                          (presumably it must already hold inv(U) from a previous call -
   !                          TODO confirm against the callers).
   !   error                  set to the status of the last step executed; on the first failing
   !                          step the routine returns immediately with that status.
   !
   ! Every exit path stops the "transform_generalized()" timer so that timer start and stop
   ! calls stay balanced even when a step fails.
   subroutine elpa_transform_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, a, b, is_already_decomposed, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, *), b(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: a(self%local_nrows, self%local_ncols), b(self%local_nrows, self%local_ncols)
#endif
     integer                :: error
     logical                :: is_already_decomposed
     integer                :: sc_desc(SC_DESC_LEN)

#ifdef FORWARD_ELPA_CANNON
     ! Process-grid information needed only by the Cannon-based reduction.
     integer(kind=ik)       :: my_prow, my_pcol, np_rows, np_cols, mpierr
     integer(kind=ik)       :: mpi_comm_rows, mpi_comm_cols
     integer(kind=ik)       :: BuffLevelInt
#endif

#if defined(FORWARD_ELPA_HERMITIAN) || defined(FORWARD_ELPA_CANNON)
     ! Both ELPA-internal variants cannot work in place and need a scratch matrix.
     MATH_DATATYPE(kind=rck) :: tmp(self%local_nrows, self%local_ncols)
#endif

     call self%timer_start("transform_generalized()")

     error = self%construct_scalapack_descriptor(sc_desc)
     if (error /= ELPA_OK) then
       call self%timer_stop("transform_generalized()")
       return
     end if

     if (.not. is_already_decomposed) then
       ! B = U^T*U, B<-U
       call self%elpa_cholesky_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if (error /= ELPA_OK) then
         call self%timer_stop("transform_generalized()")
         return
       end if

       ! B <- inv(U)
       call self%elpa_invert_trm_&
           &ELPA_IMPL_SUFFIX&
           &(b, error)
       if (error /= ELPA_OK) then
         call self%timer_stop("transform_generalized()")
         return
       end if
     end if

#ifdef FORWARD_ELPA_HERMITIAN
     ! tmp <- inv(U^T) * A (we have to use temporary variable)
     call self%elpa_hermitian_multiply_&
         &ELPA_IMPL_SUFFIX&
         &('U','F', self%na, b, a, self%local_nrows, self%local_ncols, tmp, &
                               self%local_nrows, self%local_ncols, error)
     if (error /= ELPA_OK) then
       call self%timer_stop("transform_generalized()")
       return
     end if

     ! A <- inv(U)^T * A
     a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
#endif

#ifdef FORWARD_SCALAPACK
     ! A <- inv(U)^T * A (using scalapack, we can directly update A)
     call self%timer_start("scalapack multiply inv(U)^T * A")
#ifdef WITH_MPI
     call p&
         &BLAS_CHAR&
         &trmm("L", "U", BLAS_TRANS_OR_CONJ, "N", self%na, self%na, &
               ONE, b, 1, 1, sc_desc,  a, 1, 1, sc_desc)
#else
     call BLAS_CHAR&
         &trmm("L", "U", BLAS_TRANS_OR_CONJ, "N", self%na, self%na, &
               ONE, b, self%na, a, self%na)
#endif
     call self%timer_stop("scalapack multiply inv(U)^T * A")
#endif /* FORWARD_SCALAPACK */

#if defined(FORWARD_ELPA_HERMITIAN) || defined(FORWARD_SCALAPACK)
     ! A <- inv(U)^T * A * inv(U)
     ! For this multiplication we do not have internal function in ELPA,
     ! so we have to call scalapack anyway
     call self%timer_start("scalapack multiply A * inv(U)")
#ifdef WITH_MPI
     call p&
         &BLAS_CHAR&
         &trmm("R", "U", "N", "N", self%na, self%na, &
               ONE, b, 1, 1, sc_desc, a, 1, 1, sc_desc)
#else
     call BLAS_CHAR&
         &trmm("R", "U", "N", "N", self%na, self%na, &
               ONE, b, self%na, a, self%na)
#endif
     call self%timer_stop("scalapack multiply A * inv(U)")
#endif /*(FORWARD_ELPA_HERMITIAN) || defined(FORWARD_SCALAPACK)*/

#ifdef FORWARD_ELPA_CANNON
     !TODO set the value properly
     !TODO tunable parameter?
     BuffLevelInt = 1

     ! The reduction needs valid row and column communicators; give up if either
     ! cannot be obtained instead of continuing with undefined handles.
     call self%get("mpi_comm_rows", mpi_comm_rows, error)
     if (error /= ELPA_OK) then
       call self%timer_stop("transform_generalized()")
       return
     end if
     call self%get("mpi_comm_cols", mpi_comm_cols, error)
     if (error /= ELPA_OK) then
       call self%timer_stop("transform_generalized()")
       return
     end if

     call mpi_comm_rank(mpi_comm_rows, my_prow, mpierr)
     call mpi_comm_size(mpi_comm_rows, np_rows, mpierr)
     call mpi_comm_rank(mpi_comm_cols, my_pcol, mpierr)
     call mpi_comm_size(mpi_comm_cols, np_cols, mpierr)

     ! The result is delivered in tmp (presumably inv(U)^T * A * inv(U), matching the
     ! other variants - TODO confirm against cannons_reduction) and copied back into a.
     call cannons_reduction(a, b, self%local_nrows, self%local_ncols, np_rows, np_cols, my_prow, my_pcol, &
                            sc_desc, tmp, BuffLevelInt, mpi_comm_rows, mpi_comm_cols)

     a(1:self%local_nrows, 1:self%local_ncols) = tmp(1:self%local_nrows, 1:self%local_ncols)
#endif /*FORWARD_ELPA_CANNON*/

     call self%timer_stop("transform_generalized()")
    end subroutine

130
131
132

    ! Back-transformation for the generalized eigenproblem: Q <- inv(U) * Q.
    !
    ! Arguments:
    !   self   ELPA object; supplies na, nev, local_nrows, local_ncols and the timer.
    !   b      local block of the upper triangular factor; only read here (presumably
    !          inv(U) as left behind by the forward transformation - TODO confirm).
    !   q      local block of the eigenvector matrix; the leading na x nev part is
    !          overwritten with the product.
    !   error  status of the descriptor construction; on failure the routine returns
    !          immediately with that status and q is left untouched.
    !
    ! The outer timer is stopped on the error path as well, so that timer start and stop
    ! calls stay balanced.
    subroutine elpa_transform_back_generalized_&
            &ELPA_IMPL_SUFFIX&
            &(self, b, q, error)
        implicit none
#include "general/precision_kinds.F90"
        class(elpa_impl_t)  :: self
#ifdef USE_ASSUMED_SIZE
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, *), q(self%local_nrows, *)
#else
      MATH_DATATYPE(kind=rck) :: b(self%local_nrows, self%local_ncols), q(self%local_nrows, self%local_ncols)
#endif
     integer                :: error
     integer                :: sc_desc(SC_DESC_LEN)

     call self%timer_start("transform_back_generalized()")

     error = self%construct_scalapack_descriptor(sc_desc)
     if (error /= ELPA_OK) then
       call self%timer_stop("transform_back_generalized()")
       return
     end if

     ! Q <- inv(U) * Q
     ! Left-sided multiply with an upper triangular, non-transposed, non-unit-diagonal matrix.
     call self%timer_start("scalapack multiply inv(U) * Q")
#ifdef WITH_MPI
     call p&
         &BLAS_CHAR&
         &trmm("L", "U", "N", "N", self%na, self%nev, &
               ONE, b, 1, 1, sc_desc,  q, 1, 1, sc_desc)
#else
     call BLAS_CHAR&
         &trmm("L", "U", "N", "N", self%na, self%nev, &
               ONE, b, self%na, q, self%na)
#endif
     call self%timer_stop("scalapack multiply inv(U) * Q")

     call self%timer_stop("transform_back_generalized()")

    end subroutine