#if 0
!    This file is part of ELPA.
!
!    The ELPA library was originally created by the ELPA consortium,
!    consisting of the following organizations:
!
!    - Max Planck Computing and Data Facility (MPCDF), formerly known as
!      Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
!    - Bergische Universität Wuppertal, Lehrstuhl für angewandte
!      Informatik,
!    - Technische Universität München, Lehrstuhl für Informatik mit
!      Schwerpunkt Wissenschaftliches Rechnen ,
!    - Fritz-Haber-Institut, Berlin, Abt. Theorie,
!    - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
!      Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
!      and
!    - IBM Deutschland GmbH
!
!    This particular source code file contains additions, changes and
!    enhancements authored by Intel Corporation which is not part of
!    the ELPA consortium.
!
!    More information can be found here:
!    http://elpa.mpcdf.mpg.de/
!
!    ELPA is free software: you can redistribute it and/or modify
!    it under the terms of the version 3 of the license of the
!    GNU Lesser General Public License as published by the Free
!    Software Foundation.
!
!    ELPA is distributed in the hope that it will be useful,
!    but WITHOUT ANY WARRANTY; without even the implied warranty of
!    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!    GNU Lesser General Public License for more details.
!
!    You should have received a copy of the GNU Lesser General Public License
!    along with ELPA.  If not, see <http://www.gnu.org/licenses/>
!
!    ELPA reflects a substantial effort on the part of the original
!    ELPA consortium, and we ask you to respect the spirit of the
!    license that we chose: i.e., please contribute any changes you
!    may have back to the original ELPA library distribution, and keep
!    any derivatives of ELPA under the same license that we chose for
!    the original distribution, the GNU Lesser General Public License.
!
!
! ELPA1 -- Faster replacements for ScaLAPACK symmetric eigenvalue routines
!
! Copyright of the original code rests with the authors inside the ELPA
! consortium. The copyright of any additional modifications shall rest
! with their original authors, but shall adhere to the licensing terms
! distributed along with the original code in the file "COPYING".
#endif

#include "precision_macros_complex.h"


    subroutine tridiag_complex_PRECISION(na, a, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols, d, e, tau)
    !-------------------------------------------------------------------------------
    !  tridiag_complex: Reduces a distributed Hermitian matrix to tridiagonal form
    !                   (like Scalapack Routine PZHETRD)
    !
    !  Parameters
    !
    !  na          Order of matrix
    !
    !  a(lda,matrixCols)    Distributed matrix which should be reduced.
    !              Distribution is like in Scalapack.
    !              In contrast to PZHETRD, a(:,:) must be set completely (upper and lower half)
    !              a(:,:) is overwritten on exit with the Householder vectors
    !
    !  lda         Leading dimension of a
    !  matrixCols  local columns of matrix a
    !
    !  nblk        blocksize of cyclic distribution, must be the same in both directions!
    !
    !  mpi_comm_rows
    !  mpi_comm_cols
    !              MPI-Communicators for rows/columns
    !
    !  d(na)       Diagonal elements (returned), identical on all processors
    !
    !  e(na)       Off-Diagonal elements (returned), identical on all processors
    !
    !  tau(na)     Factors for the Householder vectors (returned), needed for back transformation
    !
    !-------------------------------------------------------------------------------
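    !  A minimal call sketch (illustrative only; the names a_local, d, e and tau
    !  below are placeholders, not part of ELPA). It assumes the matrix is
    !  block-cyclically distributed and that mpi_comm_rows / mpi_comm_cols were
    !  obtained by splitting the global communicator along the process grid:
    !
    !    call tridiag_complex_PRECISION(na, a_local, lda, nblk, matrixCols, &
    !                                   mpi_comm_rows, mpi_comm_cols, d, e, tau)
    !-------------------------------------------------------------------------------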
#ifdef HAVE_DETAILED_TIMINGS
      use timings
#else
      use timings_dummy
#endif
      use precision
      implicit none

      integer(kind=ik)              :: na, lda, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
      complex(kind=COMPLEX_DATATYPE)              :: tau(na)
#ifdef USE_ASSUMED_SIZE
      complex(kind=COMPLEX_DATATYPE)              :: a(lda,*)
#else
      complex(kind=COMPLEX_DATATYPE)              :: a(lda,matrixCols)
#endif
      real(kind=REAL_DATATYPE)                 :: d(na), e(na)

      integer(kind=ik), parameter   :: max_stored_rows = 32
#ifdef DOUBLE_PRECISION_COMPLEX
      complex(kind=ck8), parameter   :: CZERO = (0.0_rk8,0.0_rk8), CONE = (1.0_rk8,0.0_rk8)
#else
      complex(kind=ck4), parameter   :: CZERO = (0.0_rk4,0.0_rk4), CONE = (1.0_rk4,0.0_rk4)
#endif
      integer(kind=ik)              :: my_prow, my_pcol, np_rows, np_cols, mpierr
      integer(kind=ik)              :: totalblocks, max_blocks_row, max_blocks_col, max_local_rows, max_local_cols
      integer(kind=ik)              :: l_cols, l_rows, nstor
      integer(kind=ik)              :: istep, i, j, lcs, lce, lrs, lre
      integer(kind=ik)              :: tile_size, l_rows_tile, l_cols_tile

#ifdef WITH_OPENMP
      integer(kind=ik)              :: my_thread, n_threads, max_threads, n_iter
      integer(kind=ik)              :: omp_get_thread_num, omp_get_num_threads, omp_get_max_threads
#endif

      real(kind=REAL_DATATYPE)                 :: vnorm2
      complex(kind=COMPLEX_DATATYPE)              :: vav, xc, aux(2*max_stored_rows),  aux1(2), aux2(2), vrl, xf

      complex(kind=COMPLEX_DATATYPE), allocatable :: tmp(:), vr(:), vc(:), ur(:), uc(:), vur(:,:), uvc(:,:)
#ifdef WITH_OPENMP
      complex(kind=COMPLEX_DATATYPE), allocatable :: ur_p(:,:), uc_p(:,:)
#endif
      real(kind=REAL_DATATYPE), allocatable    :: tmpr(:)
      integer(kind=ik)              :: istat
      character(200)                :: errorMessage

      call timer%start("tridiag_complex" // PRECISION_SUFFIX)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%start("mpi_communication")
#endif

      call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
      call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
      call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
      call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%stop("mpi_communication")
#endif

      ! Matrix is split into tiles; work is done only for tiles on the diagonal or above

      tile_size = nblk*least_common_multiple(np_rows,np_cols) ! minimum global tile size
      tile_size = ((128*max(np_rows,np_cols)-1)/tile_size+1)*tile_size ! make local tiles at least 128 wide

      l_rows_tile = tile_size/np_rows ! local rows of a tile
      l_cols_tile = tile_size/np_cols ! local cols of a tile
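      ! Worked example with hypothetical values: for nblk=32, np_rows=4, np_cols=3
      ! the minimum global tile size is 32*lcm(4,3) = 384; since 128*max(4,3) = 512
      ! is larger, tile_size is rounded up to 768, giving l_rows_tile = 768/4 = 192
      ! and l_cols_tile = 768/3 = 256.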


      totalblocks = (na-1)/nblk + 1
      max_blocks_row = (totalblocks-1)/np_rows + 1
      max_blocks_col = (totalblocks-1)/np_cols + 1

      max_local_rows = max_blocks_row*nblk
      max_local_cols = max_blocks_col*nblk

      allocate(tmp(MAX(max_local_rows,max_local_cols)), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating tmp "//errorMessage
       stop
      endif

      allocate(vr(max_local_rows+1), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating vr "//errorMessage
       stop
      endif

      allocate(ur(max_local_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating ur "//errorMessage
       stop
      endif

      allocate(vc(max_local_cols), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating vc "//errorMessage
       stop
      endif

      allocate(uc(max_local_cols), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating uc "//errorMessage
       stop
      endif

#ifdef WITH_OPENMP
      max_threads = omp_get_max_threads()

      allocate(ur_p(max_local_rows,0:max_threads-1), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating ur_p "//errorMessage
       stop
      endif

      allocate(uc_p(max_local_cols,0:max_threads-1), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating uc_p "//errorMessage
       stop
      endif
#endif

      tmp = 0
      vr = 0
      ur = 0
      vc = 0
      uc = 0

      allocate(vur(max_local_rows,2*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating vur "//errorMessage
       stop
      endif

      allocate(uvc(max_local_cols,2*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating uvc "//errorMessage
       stop
      endif

      d(:) = 0
      e(:) = 0
      tau(:) = 0

      nstor = 0

      l_rows = local_index(na, my_prow, np_rows, nblk, -1) ! Local rows of a
      l_cols = local_index(na, my_pcol, np_cols, nblk, -1) ! Local cols of a
      if (my_prow==prow(na, nblk, np_rows) .and. my_pcol==pcol(na, nblk, np_cols)) d(na) = a(l_rows,l_cols)

      do istep=na,3,-1

        ! Calculate number of local rows and columns of the still remaining matrix
        ! on the local processor

        l_rows = local_index(istep-1, my_prow, np_rows, nblk, -1)
        l_cols = local_index(istep-1, my_pcol, np_cols, nblk, -1)

        ! Calculate vector for Householder transformation on all procs
        ! owning column istep

        if (my_pcol==pcol(istep, nblk, np_cols)) then

          ! Get vector to be transformed; distribute last element and norm of
          ! remaining elements to all procs in current column

          vr(1:l_rows) = a(1:l_rows,l_cols+1)
          if (nstor>0 .and. l_rows>0) then
            aux(1:2*nstor) = conjg(uvc(l_cols+1,1:2*nstor))
            call PRECISION_GEMV('N', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), &
                        aux, 1, CONE, vr, 1)
          endif

          if (my_prow==prow(istep-1, nblk, np_rows)) then
            aux1(1) = dot_product(vr(1:l_rows-1),vr(1:l_rows-1))
            aux1(2) = vr(l_rows)
          else
            aux1(1) = dot_product(vr(1:l_rows),vr(1:l_rows))
            aux1(2) = 0.
          endif
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
          call timer%start("mpi_communication")
#endif
          call mpi_allreduce(aux1, aux2, 2, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
          vnorm2 = aux2(1)
          vrl    = aux2(2)
#ifdef HAVE_DETAILED_TIMINGS
          call timer%stop("mpi_communication")
#endif

#else /* WITH_MPI */
!          aux2 = aux1

          vnorm2 = aux1(1)
          vrl    = aux1(2)

#endif /* WITH_MPI */

!          vnorm2 = aux2(1)
!          vrl    = aux2(2)

          ! Householder transformation
          call hh_transform_complex_PRECISION(vrl, vnorm2, xf, tau(istep))
          ! Scale vr and store Householder vector for back transformation

          vr(1:l_rows) = vr(1:l_rows) * xf
          if (my_prow==prow(istep-1, nblk, np_rows)) then
            vr(l_rows) = 1.
            e(istep-1) = vrl
          endif
          a(1:l_rows,l_cols+1) = vr(1:l_rows) ! store Householder vector for back transformation

        endif

        ! Broadcast the Householder vector (and tau) along columns

        if (my_pcol==pcol(istep, nblk, np_cols)) vr(l_rows+1) = tau(istep)
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
        call timer%start("mpi_communication")
#endif
        call MPI_Bcast(vr, l_rows+1, MPI_COMPLEX_PRECISION, pcol(istep, nblk, np_cols), mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
        call timer%stop("mpi_communication")
#endif

#endif /* WITH_MPI */
        tau(istep) =  vr(l_rows+1)

        ! Transpose Householder vector vr -> vc

!        call elpa_transpose_vectors  (vr, 2*ubound(vr,dim=1), mpi_comm_rows, &
!                                      vc, 2*ubound(vc,dim=1), mpi_comm_cols, &
!                                      1, 2*(istep-1), 1, 2*nblk)
        call elpa_transpose_vectors_complex_PRECISION  (vr, ubound(vr,dim=1), mpi_comm_rows, &
                                              vc, ubound(vc,dim=1), mpi_comm_cols, &
                                              1, (istep-1), 1, nblk)
        ! Calculate u = (A + VU**T + UV**T)*v

        ! For cache efficiency, we use only the upper half of the matrix tiles for this,
        ! thus the result is partly in uc(:) and partly in ur(:)
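        ! (Below, each stored upper-half tile is multiplied conjugate-transposed
        !  with vr to accumulate into uc; off-diagonal tiles are additionally
        !  multiplied as-is with vc, which supplies the mirrored lower-half
        !  contribution and accumulates into ur.)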

        uc(1:l_cols) = 0
        ur(1:l_rows) = 0
        if (l_rows>0 .and. l_cols>0) then

#ifdef WITH_OPENMP
          call timer%start("OpenMP parallel" // PRECISION_SUFFIX)

!$OMP PARALLEL PRIVATE(my_thread,n_threads,n_iter,i,lcs,lce,j,lrs,lre)

          my_thread = omp_get_thread_num()
          n_threads = omp_get_num_threads()

          n_iter = 0

          uc_p(1:l_cols,my_thread) = 0.
          ur_p(1:l_rows,my_thread) = 0.
#endif

          do i=0,(istep-2)/tile_size
            lcs = i*l_cols_tile+1
            lce = min(l_cols,(i+1)*l_cols_tile)
            if (lce<lcs) cycle
            do j=0,i
              lrs = j*l_rows_tile+1
              lre = min(l_rows,(j+1)*l_rows_tile)
              if (lre<lrs) cycle
#ifdef WITH_OPENMP
              if (mod(n_iter,n_threads) == my_thread) then
                call PRECISION_GEMV('C', lre-lrs+1 ,lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc_p(lcs,my_thread), 1)
                if (i/=j) then
                  call PRECISION_GEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur_p(lrs,my_thread), 1)
                endif
              endif
              n_iter = n_iter+1
#else /* WITH_OPENMP */
              call PRECISION_GEMV('C', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vr(lrs), 1, CONE, uc(lcs), 1)
              if (i/=j) then
                call PRECISION_GEMV('N', lre-lrs+1, lce-lcs+1, CONE, a(lrs,lcs), lda, vc(lcs), 1, CONE, ur(lrs), 1)
              endif
#endif /* WITH_OPENMP */
            enddo
          enddo

#ifdef WITH_OPENMP
!$OMP END PARALLEL
          call timer%stop("OpenMP parallel" // PRECISION_SUFFIX)

          do i=0,max_threads-1
            uc(1:l_cols) = uc(1:l_cols) + uc_p(1:l_cols,i)
            ur(1:l_rows) = ur(1:l_rows) + ur_p(1:l_rows,i)
          enddo
#endif

          if (nstor>0) then
            call PRECISION_GEMV('C', l_rows, 2*nstor, CONE, vur, ubound(vur,dim=1), vr,  1, CZERO, aux, 1)
            call PRECISION_GEMV('N', l_cols, 2*nstor, CONE, uvc, ubound(uvc,dim=1), aux, 1, CONE, uc, 1)
          endif

        endif

        ! Sum up all ur(:) parts along rows and add them to the uc(:) parts
        ! on the processors containing the diagonal
        ! This is only necessary if ur has been calculated, i.e. if the
        ! global tile size is smaller than the global remaining matrix

        if (tile_size < istep-1) then
          call elpa_reduce_add_vectors_complex_PRECISION  (ur, ubound(ur,dim=1), mpi_comm_rows, &
                                          uc, ubound(uc,dim=1), mpi_comm_cols, &
                                          (istep-1), 1, nblk)
        endif

        ! Sum up all the uc(:) parts, transpose uc -> ur

        if (l_cols>0) then
          tmp(1:l_cols) = uc(1:l_cols)
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
          call timer%start("mpi_communication")
#endif
          call mpi_allreduce(tmp, uc, l_cols, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
          call timer%stop("mpi_communication")
#endif

#else /* WITH_MPI */
          uc = tmp
#endif /* WITH_MPI */
        endif

!        call elpa_transpose_vectors  (uc, 2*ubound(uc,dim=1), mpi_comm_cols, &
!                                      ur, 2*ubound(ur,dim=1), mpi_comm_rows, &
!                                      1, 2*(istep-1), 1, 2*nblk)
        call elpa_transpose_vectors_complex_PRECISION  (uc, ubound(uc,dim=1), mpi_comm_cols, &
                                              ur, ubound(ur,dim=1), mpi_comm_rows, &
                                              1, (istep-1), 1, nblk)

        ! calculate u**T * v (same as v**T * (A + VU**T + UV**T) * v )

        xc = 0
        if (l_cols>0) xc = dot_product(vc(1:l_cols),uc(1:l_cols))
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
        call timer%start("mpi_communication")
#endif
        call mpi_allreduce(xc, vav, 1 , MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
        call timer%stop("mpi_communication")
#endif

#else /* WITH_MPI */
        vav = xc
#endif /* WITH_MPI */

        ! store u and v in the matrices U and V
        ! these matrices are stored combined in one here
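        ! Each Householder step appends one column pair to vur/uvc, so that the
        ! deferred update  A <- A + VUR * UVC**H  (applied tile-wise with
        ! PRECISION_GEMM once max_stored_rows pairs have accumulated, or at the
        ! end of the reduction) reproduces the usual Hermitian rank-2
        ! modification of the trailing submatrix.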

        do j=1,l_rows
          vur(j,2*nstor+1) = conjg(tau(istep))*vr(j)
          vur(j,2*nstor+2) = 0.5*conjg(tau(istep))*vav*vr(j) - ur(j)
        enddo
        do j=1,l_cols
          uvc(j,2*nstor+1) = 0.5*conjg(tau(istep))*vav*vc(j) - uc(j)
          uvc(j,2*nstor+2) = conjg(tau(istep))*vc(j)
        enddo

        nstor = nstor+1

        ! If the limit of max_stored_rows is reached, calculate A + VU**T + UV**T

        if (nstor==max_stored_rows .or. istep==3) then

          do i=0,(istep-2)/tile_size
            lcs = i*l_cols_tile+1
            lce = min(l_cols,(i+1)*l_cols_tile)
            lrs = 1
            lre = min(l_rows,(i+1)*l_rows_tile)
            if (lce<lcs .or. lre<lrs) cycle
            call PRECISION_GEMM('N', 'C', lre-lrs+1, lce-lcs+1, 2*nstor, CONE, &
                         vur(lrs,1), ubound(vur,dim=1), uvc(lcs,1), ubound(uvc,dim=1), &
                         CONE, a(lrs,lcs), lda)
          enddo

          nstor = 0

        endif

        if (my_prow==prow(istep-1, nblk, np_rows) .and. my_pcol==pcol(istep-1, nblk, np_cols)) then
          if (nstor>0) a(l_rows,l_cols) = a(l_rows,l_cols) &
                          + dot_product(vur(l_rows,1:2*nstor),uvc(l_cols,1:2*nstor))
          d(istep-1) = a(l_rows,l_cols)
        endif

      enddo ! istep

      ! Store e(1) and d(1)

      if (my_pcol==pcol(2, nblk, np_cols)) then
        if (my_prow==prow(1, nblk, np_rows)) then
          ! We use last l_cols value of loop above
          vrl = a(1,l_cols)
          call hh_transform_complex_PRECISION(vrl, CONST_REAL_0_0, xf, tau(2))
          e(1) = vrl
          a(1,l_cols) = 1. ! for consistency only
        endif

#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
        call timer%start("mpi_communication")
#endif
        call mpi_bcast(tau(2), 1, MPI_COMPLEX_PRECISION, prow(1, nblk, np_rows), mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
        call timer%stop("mpi_communication")
#endif

#endif /* WITH_MPI */
      endif

#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
      call timer%start("mpi_communication")
#endif
      call mpi_bcast(tau(2), 1, MPI_COMPLEX_PRECISION, pcol(2, nblk, np_cols), mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%stop("mpi_communication")
#endif

#endif /* WITH_MPI */


      if (my_prow==prow(1, nblk, np_rows) .and. my_pcol==pcol(1, nblk, np_cols)) d(1) = a(1,1)

      deallocate(tmp, vr, ur, vc, uc, vur, uvc, stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when deallocating tmp "//errorMessage
       stop
      endif
      ! distribute the arrays d and e to all processors

      allocate(tmpr(na), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when allocating tmpr "//errorMessage
       stop
      endif

#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
      call timer%start("mpi_communication")
#endif
      tmpr = d
      call mpi_allreduce(tmpr, d, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
      tmpr = d
      call mpi_allreduce(tmpr, d, na, MPI_REAL_PRECISION ,MPI_SUM, mpi_comm_cols, mpierr)
      tmpr = e
      call mpi_allreduce(tmpr, e, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
      tmpr = e
      call mpi_allreduce(tmpr, e, na, MPI_REAL_PRECISION, MPI_SUM, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%stop("mpi_communication")
#endif

#endif /* WITH_MPI */
      deallocate(tmpr, stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"tridiag_complex: error when deallocating tmpr "//errorMessage
       stop
      endif

      call timer%stop("tridiag_complex" // PRECISION_SUFFIX)

    end subroutine tridiag_complex_PRECISION

    subroutine trans_ev_complex_PRECISION(na, nqc, a, lda, tau, q, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
    !-------------------------------------------------------------------------------
    !  trans_ev_complex: Transforms the eigenvectors of a tridiagonal matrix back
    !                    to the eigenvectors of the original matrix
    !                    (like Scalapack Routine PZUNMTR)
    !
    !  Parameters
    !
    !  na          Order of matrix a, number of rows of matrix q
    !
    !  nqc         Number of columns of matrix q
    !
    !  a(lda,matrixCols)    Matrix containing the Householder vectors (i.e. matrix a after tridiag_complex)
    !              Distribution is like in Scalapack.
    !
    !  lda         Leading dimension of a
    !
    !  tau(na)     Factors of the Householder vectors
    !
    !  q           On input: Eigenvectors of tridiagonal matrix
    !              On output: Transformed eigenvectors
    !              Distribution is like in Scalapack.
    !
    !  ldq         Leading dimension of q
    !
    !  nblk        blocksize of cyclic distribution, must be the same in both directions!
    !
    !  mpi_comm_rows
    !  mpi_comm_cols
    !              MPI-Communicators for rows/columns
    !
    !-------------------------------------------------------------------------------
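    !  A minimal call sketch (illustrative only; a_local and q_local are
    !  placeholder names). It assumes a and tau are the outputs of
    !  tridiag_complex_PRECISION and that q holds the eigenvectors of the
    !  tridiagonal matrix on entry:
    !
    !    call trans_ev_complex_PRECISION(na, nqc, a_local, lda, tau, q_local, ldq, &
    !                                    nblk, matrixCols, mpi_comm_rows, mpi_comm_cols)
    !-------------------------------------------------------------------------------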
#ifdef HAVE_DETAILED_TIMINGS
      use timings
#else
      use timings_dummy
#endif
      use precision
      implicit none

      integer(kind=ik)              ::  na, nqc, lda, ldq, nblk, matrixCols, mpi_comm_rows, mpi_comm_cols
      complex(kind=COMPLEX_DATATYPE)              ::  tau(na)
#ifdef USE_ASSUMED_SIZE
      complex(kind=COMPLEX_DATATYPE)              :: a(lda,*), q(ldq,*)
#else
      complex(kind=COMPLEX_DATATYPE)              ::  a(lda,matrixCols), q(ldq,matrixCols)
#endif
      integer(kind=ik)              :: max_stored_rows
#ifdef DOUBLE_PRECISION_COMPLEX
      complex(kind=ck8), parameter   :: CZERO = (0.0_rk8,0.0_rk8), CONE = (1.0_rk8,0.0_rk8)
#else
      complex(kind=ck4), parameter   :: CZERO = (0.0_rk4,0.0_rk4), CONE = (1.0_rk4,0.0_rk4)
#endif
      integer(kind=ik)              :: my_prow, my_pcol, np_rows, np_cols, mpierr
      integer(kind=ik)              :: totalblocks, max_blocks_row, max_blocks_col, max_local_rows, max_local_cols
      integer(kind=ik)              :: l_cols, l_rows, l_colh, nstor
      integer(kind=ik)              :: istep, i, n, nc, ic, ics, ice, nb, cur_pcol

      complex(kind=COMPLEX_DATATYPE), allocatable :: tmp1(:), tmp2(:), hvb(:), hvm(:,:)
      complex(kind=COMPLEX_DATATYPE), allocatable :: tmat(:,:), h1(:), h2(:)
      integer(kind=ik)              :: istat
      character(200)                :: errorMessage

      call timer%start("trans_ev_complex" // PRECISION_SUFFIX)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%start("mpi_communication")
#endif

      call mpi_comm_rank(mpi_comm_rows,my_prow,mpierr)
      call mpi_comm_size(mpi_comm_rows,np_rows,mpierr)
      call mpi_comm_rank(mpi_comm_cols,my_pcol,mpierr)
      call mpi_comm_size(mpi_comm_cols,np_cols,mpierr)
#ifdef HAVE_DETAILED_TIMINGS
      call timer%stop("mpi_communication")
#endif

      totalblocks = (na-1)/nblk + 1
      max_blocks_row = (totalblocks-1)/np_rows + 1
      max_blocks_col = ((nqc-1)/nblk)/np_cols + 1  ! Columns of q!

      max_local_rows = max_blocks_row*nblk
      max_local_cols = max_blocks_col*nblk

      max_stored_rows = (63/nblk+1)*nblk
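      ! This picks the smallest multiple of nblk that is at least 64, e.g.
      ! nblk=16 -> (63/16+1)*16 = 64, nblk=48 -> 96 (integer division).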

      allocate(tmat(max_stored_rows,max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating tmat "//errorMessage
       stop
      endif

      allocate(h1(max_stored_rows*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating h1 "//errorMessage
       stop
      endif

      allocate(h2(max_stored_rows*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating h2 "//errorMessage
       stop
      endif

      allocate(tmp1(max_local_cols*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating tmp1 "//errorMessage
       stop
      endif

      allocate(tmp2(max_local_cols*max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating tmp2 "//errorMessage
       stop
      endif

      allocate(hvb(max_local_rows*nblk), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating hvb "//errorMessage
       stop
      endif

      allocate(hvm(max_local_rows,max_stored_rows), stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when allocating hvm "//errorMessage
       stop
      endif

      hvm = 0   ! Must be set to 0 !!!
      hvb = 0   ! Safety only

      l_cols = local_index(nqc, my_pcol, np_cols, nblk, -1) ! Local columns of q

      nstor = 0

      ! In the complex case tau(2) /= 0
      if (my_prow == prow(1, nblk, np_rows)) then
        q(1,1:l_cols) = q(1,1:l_cols)*(CONE-tau(2))
      endif

      do istep=1,na,nblk

        ics = MAX(istep,3)
        ice = MIN(istep+nblk-1,na)
        if (ice<ics) cycle

        cur_pcol = pcol(istep, nblk, np_cols)

        nb = 0
        do ic=ics,ice

          l_colh = local_index(ic  , my_pcol, np_cols, nblk, -1) ! Column of Householder vector
          l_rows = local_index(ic-1, my_prow, np_rows, nblk, -1) ! # rows of Householder vector


          if (my_pcol==cur_pcol) then
            hvb(nb+1:nb+l_rows) = a(1:l_rows,l_colh)
            if (my_prow==prow(ic-1, nblk, np_rows)) then
              hvb(nb+l_rows) = 1.
            endif
          endif

          nb = nb+l_rows
        enddo

#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
        call timer%start("mpi_communication")
#endif

        if (nb>0) &
           call MPI_Bcast(hvb, nb, MPI_COMPLEX_PRECISION, cur_pcol, mpi_comm_cols, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
        call timer%stop("mpi_communication")
#endif

#endif /* WITH_MPI */
        nb = 0
        do ic=ics,ice
          l_rows = local_index(ic-1, my_prow, np_rows, nblk, -1) ! # rows of Householder vector
          hvm(1:l_rows,nstor+1) = hvb(nb+1:nb+l_rows)
          nstor = nstor+1
          nb = nb+l_rows
        enddo

        ! Please note: for smaller matrix sizes (na/np_rows<=256), a value of 32 for nstor is enough!
        if (nstor+nblk>max_stored_rows .or. istep+nblk>na .or. (na/np_rows<=256 .and. nstor>=32)) then

          ! Calculate scalar products of stored vectors.
          ! This can be done in different ways; here we use zherk

          tmat = 0
          if (l_rows>0) &
             call PRECISION_HERK('U', 'C', nstor, l_rows, CONE, hvm, ubound(hvm,dim=1), CZERO, tmat, max_stored_rows)
          nc = 0
          do n=1,nstor-1
            h1(nc+1:nc+n) = tmat(1:n,n+1)
            nc = nc+n
          enddo
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
          call timer%start("mpi_communication")
#endif
          if (nc>0) call mpi_allreduce(h1, h2, nc, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
          call timer%stop("mpi_communication")
#endif

#else /* WITH_MPI */
          if (nc>0) h2=h1
#endif /* WITH_MPI */
          ! Calculate triangular matrix T
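          ! (Compact WY form: T is built so that the nstor stored reflectors can
          !  be applied at once as  I - V * T * V**H,  which is what the
          !  TRMM/GEMM calls below implement.)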

          nc = 0
          tmat(1,1) = tau(ice-nstor+1)
          do n=1,nstor-1
            call PRECISION_TRMV('L', 'C', 'N', n, tmat, max_stored_rows, h2(nc+1),1)
            tmat(n+1,1:n) = -conjg(h2(nc+1:nc+n))*tau(ice-nstor+n+1)
            tmat(n+1,n+1) = tau(ice-nstor+n+1)
            nc = nc+n
          enddo

          ! Q = Q - V * T * V**T * Q

          if (l_rows>0) then
            call PRECISION_GEMM('C', 'N', nstor, l_cols, l_rows, CONE, hvm, ubound(hvm,dim=1), &
                        q, ldq, CZERO, tmp1 ,nstor)
          else
            tmp1(1:l_cols*nstor) = 0
          endif
#ifdef WITH_MPI
#ifdef HAVE_DETAILED_TIMINGS
          call timer%start("mpi_communication")
#endif

          call mpi_allreduce(tmp1, tmp2, nstor*l_cols, MPI_COMPLEX_PRECISION, MPI_SUM, mpi_comm_rows, mpierr)
#ifdef HAVE_DETAILED_TIMINGS
          call timer%stop("mpi_communication")
#endif

          if (l_rows>0) then
            call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
            call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
                        tmp2, nstor, CONE, q, ldq)
          endif

#else
!          tmp2 = tmp1

          if (l_rows>0) then
            call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp1, nstor)
            call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
                        tmp1, nstor, CONE, q, ldq)
          endif

#endif

!          if (l_rows>0) then
!            call PRECISION_TRMM('L', 'L', 'N', 'N', nstor, l_cols, CONE, tmat, max_stored_rows, tmp2, nstor)
!            call PRECISION_GEMM('N', 'N', l_rows, l_cols, nstor, -CONE, hvm, ubound(hvm,dim=1), &
!                        tmp2, nstor, CONE, q, ldq)
!          endif

          nstor = 0
        endif

      enddo

      deallocate(tmat, h1, h2, tmp1, tmp2, hvb, hvm, stat=istat, errmsg=errorMessage)
      if (istat .ne. 0) then
       print *,"trans_ev_complex: error when deallocating hvb "//errorMessage
       stop
      endif

      call timer%stop("trans_ev_complex" // PRECISION_SUFFIX)

    end subroutine trans_ev_complex_PRECISION

    subroutine hh_transform_complex_PRECISION(alpha, xnorm_sq, xf, tau)

      ! Similar to LAPACK routine ZLARFP, but uses ||x||**2 instead of x(:)
      ! and returns the factor xf by which x has to be scaled.
      ! It also does not have the special handling for numbers < 1.d-300 or > 1.d150
      ! since this would be expensive for the parallel implementation.
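      !
      ! In short (sketch of the contract): given the pivot element alpha and the
      ! squared norm xnorm_sq of the remaining vector entries, the routine
      ! returns tau and the scaling factor xf for those entries and overwrites
      ! alpha with the resulting non-negative value beta, so that the reflector
      ! H = I - tau*v*v**H (v with a 1 at the pivot position and the other
      ! entries scaled by xf) annihilates those entries and leaves beta at the
      ! pivot.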
      use precision
#ifdef HAVE_DETAILED_TIMINGS
      use timings
#else
      use timings_dummy
#endif
      implicit none
      complex(kind=COMPLEX_DATATYPE), intent(inout) :: alpha
      real(kind=REAL_DATATYPE), intent(in)       :: xnorm_sq
      complex(kind=COMPLEX_DATATYPE), intent(out)   :: xf, tau

      real(kind=REAL_DATATYPE)                   :: ALPHR, ALPHI, BETA
      
      call timer%start("hh_transform_complex" // PRECISION_SUFFIX)

      ALPHR = real( ALPHA, kind=REAL_DATATYPE )
      ALPHI = PRECISION_IMAG( ALPHA )
      if ( XNORM_SQ==0. .AND. ALPHI==0. ) then

        if ( ALPHR>=0. ) then
          TAU = 0.
        else
          TAU = 2.
          ALPHA = -ALPHA
        endif
        XF = 0.

      else

        BETA = SIGN( SQRT( ALPHR**2 + ALPHI**2 + XNORM_SQ ), ALPHR )
        ALPHA = ALPHA + BETA
        IF ( BETA<0 ) THEN
          BETA = -BETA
          TAU = -ALPHA / BETA
        ELSE
          ALPHR = ALPHI * (ALPHI/real( ALPHA , kind=KIND_PRECISION))
          ALPHR = ALPHR + XNORM_SQ/real( ALPHA, kind=KIND_PRECISION )

          TAU = PRECISION_CMPLX( ALPHR/BETA, -ALPHI/BETA )
          ALPHA = PRECISION_CMPLX( -ALPHR, ALPHI )
        END IF
        XF = CONST_REAL_1_0/ALPHA
        ALPHA = BETA
      endif

      call timer%stop("hh_transform_complex" // PRECISION_SUFFIX)

    end subroutine hh_transform_complex_PRECISION

#define ALREADY_DEFINED 1