!    This file is part of ELPA.
!
!    The ELPA library was originally created by the ELPA consortium,
!    consisting of the following organizations:
!
!    - Max Planck Computing and Data Facility (MPCDF), formerly known as
!      Rechenzentrum Garching der Max-Planck-Gesellschaft (RZG),
!    - Bergische Universität Wuppertal, Lehrstuhl für angewandte
!      Informatik,
!    - Technische Universität München, Lehrstuhl für Informatik mit
!      Schwerpunkt Wissenschaftliches Rechnen,
!    - Fritz-Haber-Institut, Berlin, Abt. Theorie,
!    - Max-Planck-Institut für Mathematik in den Naturwissenschaften,
!      Leipzig, Abt. Komplexe Strukturen in Biologie und Kognition,
!      and
!    - IBM Deutschland GmbH
!
!
!    More information can be found here:
!    http://elpa.mpcdf.mpg.de/
!
!    ELPA is free software: you can redistribute it and/or modify
!    it under the terms of the version 3 of the license of the
!    GNU Lesser General Public License as published by the Free
!    Software Foundation.
!
!    ELPA is distributed in the hope that it will be useful,
!    but WITHOUT ANY WARRANTY; without even the implied warranty of
!    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!    GNU Lesser General Public License for more details.
!
!    You should have received a copy of the GNU Lesser General Public License
!    along with ELPA.  If not, see <http://www.gnu.org/licenses/>
!
!    ELPA reflects a substantial effort on the part of the original
!    ELPA consortium, and we ask you to respect the spirit of the
!    license that we chose: i.e., please contribute any changes you
!    may have back to the original ELPA library distribution, and keep
!    any derivatives of ELPA under the same license that we chose for
!    the original distribution, the GNU Lesser General Public License.
!
!
#include "config-f90.h"

!> Distributed (1D-communicator) QR helper routines: block-reflector
!> application (pdlarfb), T-factor generation and merging, and single/
!> double Householder updates used by the ELPA QR decomposition.
module elpa_pdlarfb

    use elpa1_compute
    use qr_utils_mod
    use elpa_mpi

    implicit none

    PRIVATE

    public :: qr_pdlarfb_1dcomm
    public :: qr_pdlarft_pdlarfb_1dcomm
    public :: qr_pdlarft_set_merge_1dcomm
    public :: qr_pdlarft_tree_merge_1dcomm
    public :: qr_pdlarfl_1dcomm
    public :: qr_pdlarfl2_tmatrix_1dcomm
    public :: qr_tmerge_pdlarfb_1dcomm

contains

!> Apply a block Householder reflector (I - V*T*V') from the left to the
!> trailing matrix A, distributed over a 1D process grid.
!>
!> m, mb     : global matrix height and distribution block size
!> n, k      : columns of A to update / number of reflectors
!> a(lda,*)  : local part of the target matrix, updated in place
!> v(ldv,*)  : local part of the Householder vectors
!> tau(*)    : scalar reflector factors (used by the k=1 fallback)
!> t(ldt,*)  : triangular factor of the block reflector
!> baseidx   : global index of the start of the vector block
!> idx       : global column index of the current reflector
!> rev       : reversed-storage flag
!> mpicomm   : 1D communicator over which rows of A are distributed
!> work(k,*) : workspace; lwork = -1 performs a workspace query
subroutine qr_pdlarfb_1dcomm(m,mb,n,k,a,lda,v,ldv,tau,t,ldt,baseidx,idx,rev,mpicomm,work,lwork)
    use precision
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik)  :: lda,ldv,ldt,lwork
    real(kind=rk)     :: a(lda,*),v(ldv,*),tau(*),t(ldt,*),work(k,*)

    ! input variables (global)
    integer(kind=ik)  :: m,mb,n,k,baseidx,idx,rev,mpicomm

    ! output variables (global)

    ! derived input variables from QR_PQRPARAM

    ! local scalars
    integer(kind=ik)  :: localsize,offset,baseoffset
    integer(kind=ik)  :: mpirank,mpiprocs,mpierr

        if (idx .le. 1) return

    if (n .le. 0) return ! nothing to do

    ! small-k cases are delegated to specialized routines, which handle
    ! their own workspace queries
    if (k .eq. 1) then
        call qr_pdlarfl_1dcomm(v,1,baseidx,a,lda,tau(1), &
                                work,lwork,m,n,idx,mb,rev,mpicomm)
        return
    else if (k .eq. 2) then
        call qr_pdlarfl2_tmatrix_1dcomm(v,ldv,baseidx,a,lda,t,ldt, &
                                 work,lwork,m,n,idx,mb,rev,mpicomm)
        return
    end if

    ! workspace query: k x n send panel + k x n receive panel
    if (lwork .eq. -1) then
        work(1,1) = DBLE(2*k*n)
        return
    end if

    !print *,'updating trailing matrix with k=',k
    call MPI_Comm_rank(mpicomm,mpirank,mpierr)
    call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
    ! use baseidx as idx here, otherwise the upper triangle part will be lost
    ! during the calculation, especially in the reversed case
    call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
                                localsize,baseoffset,offset)

    ! Z' = Y' * A  (local contribution)
    if (localsize .gt. 0) then
        call dgemm("Trans","Notrans",k,n,localsize,1.0d0,v(baseoffset,1),ldv,a(offset,1),lda,0.0d0,work(1,1),k)
    else
        work(1:k,1:n) = 0.0d0
    end if

    ! data exchange: reduce the k x n panel and place the result in the
    ! receive position (columns n+1 .. 2n)
#ifdef WITH_MPI
    call mpi_allreduce(work(1,1),work(1,n+1),k*n,mpi_real8,mpi_sum,mpicomm,mpierr)
#else
    ! serial build: copy the k x n panel; the original statement
    ! work(1:k*n,n+1) = work(1:k*n,1) indexed past the first dimension
    work(1:k,n+1:2*n) = work(1:k,1:n)
#endif
    call qr_pdlarfb_kernel_local(localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t,ldt,work(1,n+1),k)
end subroutine qr_pdlarfb_1dcomm

! generalized pdlarfl2 version
! TODO: include T merge here (separate by "old" and "new" index)
!> Generate the triangular factor T of a block reflector (pdlarft) and
!> immediately apply the reflector to A (pdlarfb), combining the
!> Householder inner products and Z' = V'*A in a single allreduce.
!>
!> Workspace layout (columns of work(k,*)):
!>   send buffer:    [ V'V (k cols) | Z' (n cols) | T-merge reserve (oldk cols) ]
!>   receive buffer: same layout starting at column recvoffset
subroutine qr_pdlarft_pdlarfb_1dcomm(m,mb,n,oldk,k,v,ldv,tau,t,ldt,a,lda,baseidx,rev,mpicomm,work,lwork)
    use precision
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik)  :: ldv,ldt,lda,lwork
    real(kind=rk)     :: v(ldv,*),tau(*),t(ldt,*),work(k,*),a(lda,*)

    ! input variables (global)
    integer(kind=ik)  :: m,mb,n,k,oldk,baseidx,rev,mpicomm

    ! output variables (global)

    ! derived input variables from QR_PQRPARAM

    ! local scalars
    integer(kind=ik)  :: localsize,offset,baseoffset
    integer(kind=ik)  :: mpirank,mpiprocs,mpierr
    integer(kind=ik)  :: icol

    integer(kind=ik)  :: sendoffset,recvoffset,sendsize

    sendoffset = 1
    sendsize = k*(k+n+oldk)
    recvoffset = sendoffset+(k+n+oldk)

    if (lwork .eq. -1) then
        ! workspace query: send + receive buffer, each k x (k+n+oldk)
        ! elements (the original reported 2*(k*k+k*n+oldk), which
        ! undercounts the merge-reserve columns by a factor of k)
        work(1,1) = DBLE(2*(k*k+k*n+k*oldk))
        return
    end if
    call MPI_Comm_rank(mpicomm,mpirank,mpierr)
    call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
    call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
                                localsize,baseoffset,offset)

    if (localsize .gt. 0) then
            ! calculate inner product of householder vectors (upper triangle)
            call dsyrk("Upper","Trans",k,localsize,1.0d0,v(baseoffset,1),ldv,0.0d0,work(1,1),k)

            ! calculate matrix matrix product of householder vectors and target matrix
            ! Z' = Y' * A
            call dgemm("Trans","Notrans",k,n,localsize,1.0d0,v(baseoffset,1),ldv,a(offset,1),lda,0.0d0,work(1,k+1),k)

            ! TODO: reserved for T merge parts
            work(1:k,n+k+1:n+k+oldk) = 0.0d0
    else
        work(1:k,1:(n+k+oldk)) = 0.0d0
    end if

    ! exchange data
#ifdef WITH_MPI
    call mpi_allreduce(work(1,sendoffset),work(1,recvoffset),sendsize,mpi_real8,mpi_sum,mpicomm,mpierr)
#else
    ! serial build: copy the whole k x (k+n+oldk) send panel into the
    ! receive position (the original statement indexed past the first
    ! dimension of work)
    work(1:k,recvoffset:recvoffset+(k+n+oldk)-1) = work(1:k,sendoffset:sendoffset+(k+n+oldk)-1)
#endif
        ! generate T matrix (pdlarft) from tau and the reduced V'V block
        t(1:k,1:k) = 0.0d0 ! DEBUG: clear buffer first

        ! T1 = tau1
        ! | tauk  Tk-1' * (-tauk * Y(:,1,k+1:n) * Y(:,k))' |
        ! | 0           Tk-1                           |
        t(k,k) = tau(k)
        do icol=k-1,1,-1
            t(icol,icol+1:k) = -tau(icol)*work(icol,recvoffset+icol:recvoffset+k-1)
            call dtrmv("Upper","Trans","Nonunit",k-icol,t(icol+1,icol+1),ldt,t(icol,icol+1),ldt)
            t(icol,icol) = tau(icol)
        end do

        ! TODO: elmroth and gustavson

        ! update matrix (pdlarfb)
        ! Z' = T * Z'
        call dtrmm("Left","Upper","Notrans","Nonunit",k,n,1.0d0,t,ldt,work(1,recvoffset+k),k)

        ! A = A - Y * V'   (dgemm is a no-op when localsize == 0)
        call dgemm("Notrans","Notrans",localsize,n,k,-1.0d0,v(baseoffset,1),ldv,work(1,recvoffset+k),k,1.0d0,a(offset,1),lda)

end subroutine qr_pdlarft_pdlarfb_1dcomm

!> Compute V'V with a single allreduce and use it to merge the per-block
!> T factors of a reflector set into one combined T matrix.
!>
!> work(n,*) layout: columns 1..n hold the local V'V contribution,
!> columns n+1..2n receive the reduced result.
subroutine qr_pdlarft_set_merge_1dcomm(m,mb,n,blocksize,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
    use precision
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik)  :: ldv,ldt,lwork
    real(kind=rk)     :: v(ldv,*),t(ldt,*),work(n,*)

    ! input variables (global)
    integer(kind=ik)  :: m,mb,n,blocksize,baseidx,rev,mpicomm

    ! output variables (global)

    ! derived input variables from QR_PQRPARAM

    ! local scalars
    integer(kind=ik)  :: localsize,offset,baseoffset
    integer(kind=ik)  :: mpirank,mpiprocs,mpierr

    ! workspace query: n x n send + n x n receive buffer
    if (lwork .eq. -1) then
        work(1,1) = DBLE(2*n*n)
        return
    end if
    call MPI_Comm_rank(mpicomm,mpirank,mpierr)
    call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
    call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
                                localsize,baseoffset,offset)

    ! local contribution to V'V (upper triangle only)
    if (localsize .gt. 0) then
        call dsyrk("Upper","Trans",n,localsize,1.0d0,v(baseoffset,1),ldv,0.0d0,work(1,1),n)
    else
        work(1:n,1:n) = 0.0d0
    end if
#ifdef WITH_MPI
    call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real8,mpi_sum,mpicomm,mpierr)
#else
    work(1:n,n+1:2*n) = work(1:n,1:n)
#endif
        ! skip Y4'*Y4 part (the last, possibly partial, block)
        offset = mod(n,blocksize)
        if (offset .eq. 0) offset=blocksize
        call qr_tmerge_set_kernel(n,blocksize,t,ldt,work(1,n+1+offset),n)

end subroutine qr_pdlarft_set_merge_1dcomm

!> Compute V'V with a single allreduce and merge the per-block T factors
!> of a reflector set tree-wise with the given tree order.
!>
!> work(n,*) layout: columns 1..n hold the local V'V contribution,
!> columns n+1..2n receive the reduced result.
subroutine qr_pdlarft_tree_merge_1dcomm(m,mb,n,blocksize,treeorder,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
    use precision
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik) :: ldv,ldt,lwork
    real(kind=rk)    :: v(ldv,*),t(ldt,*),work(n,*)

    ! input variables (global)
    integer(kind=ik) :: m,mb,n,blocksize,treeorder,baseidx,rev,mpicomm

    ! output variables (global)

    ! derived input variables from QR_PQRPARAM

    ! local scalars
    integer(kind=ik) :: localsize,offset,baseoffset
    integer(kind=ik) :: mpirank,mpiprocs,mpierr

    ! workspace query: n x n send + n x n receive buffer
    if (lwork .eq. -1) then
        work(1,1) = DBLE(2*n*n)
        return
    end if

    if (n .le. blocksize) return ! nothing to do

    call MPI_Comm_rank(mpicomm,mpirank,mpierr)
    call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
    call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
                                localsize,baseoffset,offset)

    ! local contribution to V'V (upper triangle only)
    if (localsize .gt. 0) then
        call dsyrk("Upper","Trans",n,localsize,1.0d0,v(baseoffset,1),ldv,0.0d0,work(1,1),n)
    else
        work(1:n,1:n) = 0.0d0
    end if
#ifdef WITH_MPI
    call mpi_allreduce(work(1,1),work(1,n+1),n*n,mpi_real8,mpi_sum,mpicomm,mpierr)
#else
    work(1:n,n+1:2*n) = work(1:n,1:n)
#endif
        ! skip Y4'*Y4 part (the last, possibly partial, block)
        offset = mod(n,blocksize)
        if (offset .eq. 0) offset=blocksize
        call qr_tmerge_tree_kernel(n,blocksize,treeorder,t,ldt,work(1,n+1+offset),n)

end subroutine qr_pdlarft_tree_merge_1dcomm

! apply householder vector to the left
! - assume unitary matrix
! - assume right positions for v
!> Apply a single Householder reflector (I - tau * v * v') from the left
!> to the n columns of the distributed matrix A, using one allreduce for
!> the projections v'*A.
!>
!> work(*) layout: elements 1..n are the local projections (send buffer),
!> elements n+1..2n receive the reduced result.
subroutine qr_pdlarfl_1dcomm(v,incv,baseidx,a,lda,tau,work,lwork,m,n,idx,mb,rev,mpicomm)
    use precision
    use ELPA1
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik) :: incv,lda,lwork,baseidx
    real(kind=rk)    :: v(*),a(lda,*),work(*)

    ! input variables (global)
    integer(kind=ik) :: m,n,mb,rev,idx,mpicomm
    real(kind=rk)    :: tau

    ! output variables (global)

    ! local scalars
    integer(kind=ik) :: mpierr,mpirank,mpiprocs
    integer(kind=ik) :: sendsize,recvsize,icol
    integer(kind=ik) :: local_size,local_offset
    integer(kind=ik) :: v_local_offset

    call MPI_Comm_rank(mpicomm, mpirank, mpierr)
    call MPI_Comm_size(mpicomm, mpiprocs, mpierr)
    sendsize = n
    recvsize = sendsize

    ! workspace query
    if (lwork .eq. -1) then
        work(1) = DBLE(sendsize + recvsize)
        return
    end if

    if (n .le. 0) return

        if (idx .le. 1) return

    call local_size_offset_1d(m,mb,baseidx,idx,rev,mpirank,mpiprocs, &
                              local_size,v_local_offset,local_offset)

    !print *,'hl ref',local_size,n

    ! NOTE(review): the slice below assumes stride-1 access into v even
    ! for incv > 1 — behavior kept as in the original
    v_local_offset = v_local_offset * incv

    ! local projections: work(icol) = v' * A(:,icol)
    if (local_size > 0) then
        do icol=1,n
            work(icol) = dot_product(v(v_local_offset:v_local_offset+local_size-1),a(local_offset:local_offset+local_size-1,icol))

        end do
    else
        work(1:n) = 0.0d0
    end if
#ifdef WITH_MPI
    call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real8, mpi_sum, mpicomm, mpierr)
#else
    ! serial build: copy exactly sendsize elements (the original slice
    ! had mismatched extents on the left-hand side)
    work(sendsize+1:2*sendsize) = work(1:sendsize)
#endif
    ! rank-1 update: A = A - tau * v * (v'A)
    if (local_size > 0) then

         do icol=1,n
               a(local_offset:local_offset+local_size-1,icol) = a(local_offset:local_offset+local_size-1,icol) &
                                                                - tau*work(sendsize+icol)*v(v_local_offset:v_local_offset+ &
                                                                           local_size-1)
         enddo
    end if

end subroutine qr_pdlarfl_1dcomm

!> Apply two Householder reflectors, given their 2x2 T matrix, from the
!> left to the distributed matrix A in a single communication round.
!>
!> work(*) layout: elements 1..n and n+1..2n are the two local projection
!> vectors (send buffer); the reduced results follow at sendsize+1.
subroutine qr_pdlarfl2_tmatrix_1dcomm(v,ldv,baseidx,a,lda,t,ldt,work,lwork,m,n,idx,mb,rev,mpicomm)
    use precision
    use ELPA1
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik) :: ldv,lda,lwork,baseidx,ldt
    real(kind=rk)    :: v(ldv,*),a(lda,*),work(*),t(ldt,*)

    ! input variables (global)
    integer(kind=ik) :: m,n,mb,rev,idx,mpicomm

    ! output variables (global)

    ! local scalars
    integer(kind=ik) :: mpierr,mpirank,mpiprocs,mpirank_top1,mpirank_top2
    integer(kind=ik) :: dgemv1_offset,dgemv2_offset
    integer(kind=ik) :: sendsize, recvsize
    integer(kind=ik) :: local_size1,local_offset1
    integer(kind=ik) :: local_size2,local_offset2
    integer(kind=ik) :: local_size_dger,local_offset_dger
    integer(kind=ik) :: v1_local_offset,v2_local_offset
    integer(kind=ik) :: v_local_offset_dger
    real(kind=rk)    :: hvdot
    integer(kind=ik) :: irow,icol,v1col,v2col

    call MPI_Comm_rank(mpicomm, mpirank, mpierr)
    call MPI_Comm_size(mpicomm, mpiprocs, mpierr)
    sendsize = 2*n
    recvsize = sendsize

    ! workspace query
    if (lwork .eq. -1) then
        work(1) = sendsize + recvsize
        return
    end if

    dgemv1_offset = 1
    dgemv2_offset = dgemv1_offset + n

        ! in 2x2 matrix case only one householder vector was generated
        if (idx .le. 2) then
            call qr_pdlarfl_1dcomm(v(1,2),1,baseidx,a,lda,t(2,2), &
                                    work,lwork,m,n,idx,mb,rev,mpicomm)
            return
        end if

        ! local extents for the two reflectors (they start one row apart)
        call local_size_offset_1d(m,mb,baseidx,idx,rev,mpirank,mpiprocs, &
                                  local_size1,v1_local_offset,local_offset1)
        call local_size_offset_1d(m,mb,baseidx,idx-1,rev,mpirank,mpiprocs, &
                                  local_size2,v2_local_offset,local_offset2)

        v1_local_offset = v1_local_offset * 1
        v2_local_offset = v2_local_offset * 1

        v1col = 2
        v2col = 1

        ! keep buffers clean in case that local_size1/local_size2 are zero
        work(1:sendsize) = 0.0d0

        ! local projections of A onto the two reflectors
        call dgemv("Trans",local_size1,n,1.0d0,a(local_offset1,1),lda,v(v1_local_offset,v1col),1,0.0d0,work(dgemv1_offset),1)
        call dgemv("Trans",local_size2,n,t(v2col,v2col),a(local_offset2,1),lda,v(v2_local_offset,v2col),1,0.0d0, &
                   work(dgemv2_offset),1)
#ifdef WITH_MPI
        call mpi_allreduce(work, work(sendsize+1), sendsize, mpi_real8, mpi_sum, mpicomm, mpierr)
#else
        work(sendsize+1:2*sendsize) = work(1:sendsize)
#endif
        ! update second vector: fold in the T(1,2) coupling term
        call daxpy(n,t(1,2),work(sendsize+dgemv1_offset),1,work(sendsize+dgemv2_offset),1)

        call local_size_offset_1d(m,mb,baseidx,idx-2,rev,mpirank,mpiprocs, &
                                  local_size_dger,v_local_offset_dger,local_offset_dger)

        ! get ranks of processes with topelements
        mpirank_top1 = MOD((idx-1)/mb,mpiprocs)
        mpirank_top2 = MOD((idx-2)/mb,mpiprocs)

        ! NOTE(review): the top element of each reflector is the last
        ! local row on its owning rank, hence offset := size here
        if (mpirank_top1 .eq. mpirank) local_offset1 = local_size1
        if (mpirank_top2 .eq. mpirank) then
            local_offset2 = local_size2
            v2_local_offset = local_size2
        end if

    ! use hvdot as temporary variable
    hvdot = t(v1col,v1col)
    do icol=1,n
        ! make use of "1" entries in householder vectors
        if (mpirank_top1 .eq. mpirank) then
            a(local_offset1,icol) = a(local_offset1,icol) &
                                    - work(sendsize+dgemv1_offset+icol-1)*hvdot
        end if

        if (mpirank_top2 .eq. mpirank) then
            a(local_offset2,icol) = a(local_offset2,icol) &
                                    - v(v2_local_offset,v1col)*work(sendsize+dgemv1_offset+icol-1)*hvdot &
                                    - work(sendsize+dgemv2_offset+icol-1)
        end if

        ! rank-2 update of the rows below both top elements
        do irow=1,local_size_dger
            a(local_offset_dger+irow-1,icol) = a(local_offset_dger+irow-1,icol) &
                                    - work(sendsize+dgemv1_offset+icol-1)*v(v_local_offset_dger+irow-1,v1col)*hvdot &
                                    - work(sendsize+dgemv2_offset+icol-1)*v(v_local_offset_dger+irow-1,v2col)
        end do
    end do

end subroutine qr_pdlarfl2_tmatrix_1dcomm

! generalized pdlarfl2 version
! TODO: include T merge here (separate by "old" and "new" index)
!> Merge an old T factor with a new one and apply the (merged) block
!> reflector to A, combining all reductions in a single allreduce.
!>
!> updatemode ichar('I') merges first and updates A with the complete
!> (oldk+k) reflector; any other mode updates with the new k reflectors
!> only.
!>
!> work(*) send-buffer layout (element offsets relative to sendoffset):
!>   [ update part (updatelda x n) | merge part (k x oldk) | tgen (unused) ]
subroutine qr_tmerge_pdlarfb_1dcomm(m,mb,n,oldk,k,v,ldv,t,ldt,a,lda,baseidx,rev,updatemode,mpicomm,work,lwork)
    use precision
    use qr_utils_mod

    implicit none

    ! input variables (local)
    integer(kind=ik) :: ldv,ldt,lda,lwork
    real(kind=rk)    :: v(ldv,*),t(ldt,*),work(*),a(lda,*)

    ! input variables (global)
    integer(kind=ik) :: m,mb,n,k,oldk,baseidx,rev,updatemode,mpicomm

    ! output variables (global)

    ! derived input variables from QR_PQRPARAM

    ! local scalars
    integer(kind=ik) :: localsize,offset,baseoffset
    integer(kind=ik) :: mpirank,mpiprocs,mpierr

    integer(kind=ik) :: sendoffset,recvoffset,sendsize
    integer(kind=ik) :: updateoffset,updatelda,updatesize
    integer(kind=ik) :: mergeoffset,mergelda,mergesize
    integer(kind=ik) :: tgenoffset,tgenlda,tgensize

        if (updatemode .eq. ichar('I')) then
            updatelda = oldk+k
        else
            updatelda = k
        end if

        updatesize = updatelda*n

        mergelda = k
        mergesize = mergelda*oldk

        ! T generation part is currently unused
        tgenlda = 0
        tgensize = 0

        sendsize = updatesize + mergesize + tgensize

    ! workspace query: send + receive buffer
    if (lwork .eq. -1) then
        work(1) = DBLE(2*sendsize)
        return
    end if
    call MPI_Comm_rank(mpicomm,mpirank,mpierr)
    call MPI_Comm_size(mpicomm,mpiprocs,mpierr)
    ! use baseidx as idx here, otherwise the upper triangle part will be lost
    ! during the calculation, especially in the reversed case
    call local_size_offset_1d(m,mb,baseidx,baseidx,rev,mpirank,mpiprocs, &
                                localsize,baseoffset,offset)

    sendoffset = 1

        if (oldk .gt. 0) then
            updateoffset = 0
            mergeoffset = updateoffset + updatesize
            tgenoffset = mergeoffset + mergesize

            sendsize = updatesize + mergesize + tgensize

            !print *,'sendsize',sendsize,updatesize,mergesize,tgensize
            !print *,'merging nr of rotations', oldk+k

            if (localsize .gt. 0) then
                ! calculate matrix matrix product of householder vectors and target matrix

                if (updatemode .eq. ichar('I')) then
                    ! Z' = (Y1,Y2)' * A
                    call dgemm("Trans","Notrans",k+oldk,n,localsize,1.0d0,v(baseoffset,1),ldv,a(offset,1),lda,0.0d0, &
                               work(sendoffset+updateoffset),updatelda)
                else
                    ! Z' = Y1' * A
                    call dgemm("Trans","Notrans",k,n,localsize,1.0d0,v(baseoffset,1),ldv,a(offset,1),lda,0.0d0, &
                               work(sendoffset+updateoffset),updatelda)
                end if

                ! calculate parts needed for T merge
                call dgemm("Trans","Notrans",k,oldk,localsize,1.0d0,v(baseoffset,1),ldv,v(baseoffset,k+1),ldv,0.0d0, &
                           work(sendoffset+mergeoffset),mergelda)

            else
                ! cleanup buffer
                work(sendoffset:sendoffset+sendsize-1) = 0.0d0
            end if
        else
            ! do not calculate parts for T merge as there is nothing to merge

            updateoffset = 0
            ! unused in this branch, but defined so the offset adjustment
            ! below does not read an undefined value
            mergeoffset = 0

            tgenoffset = updateoffset + updatesize

            sendsize = updatesize + tgensize

            if (localsize .gt. 0) then
                ! calculate matrix matrix product of householder vectors and target matrix
                ! Z' = (Y1)' * A
                call dgemm("Trans","Notrans",k,n,localsize,1.0d0,v(baseoffset,1),ldv,a(offset,1),lda,0.0d0, &
                           work(sendoffset+updateoffset),updatelda)

            else
                ! cleanup buffer
                work(sendoffset:sendoffset+sendsize-1) = 0.0d0
            end if

        end if

    recvoffset = sendoffset + sendsize

    if (sendsize .le. 0) return ! nothing to do

    ! exchange data
#ifdef WITH_MPI
    call mpi_allreduce(work(sendoffset),work(recvoffset),sendsize,mpi_real8,mpi_sum,mpicomm,mpierr)
#else
    work(recvoffset:recvoffset+sendsize-1) = work(sendoffset:sendoffset+sendsize-1)
#endif
    ! switch to absolute positions in the receive buffer
    updateoffset = recvoffset+updateoffset
    mergeoffset = recvoffset+mergeoffset
    tgenoffset = recvoffset+tgenoffset

        if (oldk .gt. 0) then
            call qr_pdlarft_merge_kernel_local(oldk,k,t,ldt,work(mergeoffset),mergelda)

            if (localsize .gt. 0) then
                if (updatemode .eq. ichar('I')) then

                    ! update matrix (pdlarfb) with complete T
                    call qr_pdlarfb_kernel_local(localsize,n,k+oldk,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
                                                 work(updateoffset),updatelda)
                else
                    ! update matrix (pdlarfb) with small T (same as update with no old T TODO)
                    call qr_pdlarfb_kernel_local(localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
                                                 work(updateoffset),updatelda)
                end if
            end if
        else
            if (localsize .gt. 0) then
                ! update matrix (pdlarfb) with small T
                call qr_pdlarfb_kernel_local(localsize,n,k,a(offset,1),lda,v(baseoffset,1),ldv,t(1,1),ldt, &
                                             work(updateoffset),updatelda)
            end if
        end if

end subroutine qr_tmerge_pdlarfb_1dcomm

end module elpa_pdlarfb