Commit b3afa520 authored by Andreas Marek
Browse files

Fix interface in elpa_pdlarfb

parent 776cee73
......@@ -47,6 +47,7 @@ subroutine qr_pdlarfb_1dcomm_&
(m,mb,n,k,a,lda,v,ldv,tau,t,ldt,baseidx,idx,rev,mpicomm,work,lwork)
use precision
use qr_utils_mod
use elpa_blas_interfaces
implicit none
......@@ -106,9 +107,13 @@ subroutine qr_pdlarfb_1dcomm_&
! Z' = Y' * A
if (localsize .gt. 0) then
#ifdef DOUBLE_PRECISION_REAL
call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8,work(1,1),k)
! work(1:k,1:n) = v(baseoffset:,1:k)^T * a(offset:,1:n), summed over localsize rows.
! Every integer argument, including the inner dimension localsize, is converted to
! BLAS_KIND, matching the other converted gemm calls in this file.
call dgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda, kind=BLAS_KIND), 0.0_rk8, work(1,1), int(k,kind=BLAS_KIND))
#else
call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4,work(1,1),k)
! Single-precision counterpart of the product above; localsize is converted to
! BLAS_KIND here as well so all integer arguments share one kind.
call sgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda, kind=BLAS_KIND), 0.0_rk4, work(1,1), int(k,kind=BLAS_KIND))
#endif
else
#ifdef DOUBLE_PRECISION_REAL
......@@ -144,6 +149,7 @@ subroutine qr_pdlarft_pdlarfb_1dcomm_&
(m,mb,n,oldk,k,v,ldv,tau,t,ldt,a,lda,baseidx,rev,mpicomm,work,lwork)
use precision
use qr_utils_mod
use elpa_blas_interfaces
implicit none
......@@ -189,11 +195,14 @@ subroutine qr_pdlarft_pdlarfb_1dcomm_&
#ifdef DOUBLE_PRECISION_REAL
if (localsize .gt. 0) then
! calculate inner product of householdervectors
call dsyrk("Upper","Trans",k,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),k)
call dsyrk("Upper", "Trans", int(k,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), 1.0_rk8, &
v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk8, work(1,1), int(k,kind=BLAS_KIND))
! calculate matrix matrix product of householder vectors and target matrix
! Z' = Y' * A
call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8,work(1,k+1),k)
! The k-by-n result is stored starting at column k+1 of work.
! localsize is converted to BLAS_KIND like the other integer arguments, in line
! with the single-precision sgemm call of this routine.
call dgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), a(offset,1), &
int(lda,kind=BLAS_KIND), 0.0_rk8, work(1,k+1), int(k,kind=BLAS_KIND))
! TODO: reserved for T merge parts
work(1:k,n+k+1:n+k+oldk) = 0.0_rk8
......@@ -203,11 +212,14 @@ subroutine qr_pdlarft_pdlarfb_1dcomm_&
#else /* DOUBLE_PRECISION_REAL */
if (localsize .gt. 0) then
! calculate inner product of householdervectors
call ssyrk("Upper","Trans",k,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),k)
call ssyrk("Upper", "Trans", int(k,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), 1.0_rk4, &
v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk4, work(1,1), int(k,kind=BLAS_KIND))
! calculate matrix matrix product of householder vectors and target matrix
! Z' = Y' * A
call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4,work(1,k+1),k)
call sgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), a(offset,1), int(lda,kind=BLAS_KIND), &
0.0_rk4, work(1,k+1), int(k,kind=BLAS_KIND))
! TODO: reserved for T merge parts
work(1:k,n+k+1:n+k+oldk) = 0.0_rk4
......@@ -243,9 +255,11 @@ subroutine qr_pdlarft_pdlarfb_1dcomm_&
do icol=k-1,1,-1
t(icol,icol+1:k) = -tau(icol)*work(icol,recvoffset+icol:recvoffset+k-1)
#ifdef DOUBLE_PRECISION_REAL
call dtrmv("Upper","Trans","Nonunit",k-icol,t(icol+1,icol+1),ldt,t(icol,icol+1),ldt)
call dtrmv("Upper", "Trans", "Nonunit", int(k-icol,kind=BLAS_KIND), t(icol+1,icol+1), &
int(ldt,kind=BLAS_KIND), t(icol,icol+1), int(ldt,kind=BLAS_KIND))
#else
call strmv("Upper","Trans","Nonunit",k-icol,t(icol+1,icol+1),ldt,t(icol,icol+1),ldt)
call strmv("Upper","Trans","Nonunit",int(k-icol,kind=BLAS_KIND), t(icol+1,icol+1), &
int(ldt,kind=BLAS_KIND), t(icol,icol+1), int(ldt,kind=BLAS_KIND))
#endif
t(icol,icol) = tau(icol)
end do
......@@ -279,7 +293,7 @@ subroutine qr_pdlarft_set_merge_1dcomm_&
(m,mb,n,blocksize,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
use precision
use qr_utils_mod
use elpa_blas_interfaces
implicit none
! input variables (local)
......@@ -316,13 +330,15 @@ subroutine qr_pdlarft_set_merge_1dcomm_&
localsize,baseoffset,offset)
#ifdef DOUBLE_PRECISION_REAL
if (localsize .gt. 0) then
call dsyrk("Upper","Trans",n,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),n)
call dsyrk("Upper", "Trans", int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk8, work(1,1), int(n,kind=BLAS_KIND))
else
work(1:n,1:n) = 0.0_rk8
end if
#else
if (localsize .gt. 0) then
call ssyrk("Upper","Trans",n,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),n)
call ssyrk("Upper", "Trans", int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk4, work(1,1), int(n,kind=BLAS_KIND))
else
work(1:n,1:n) = 0.0_rk4
end if
......@@ -356,7 +372,7 @@ subroutine qr_pdlarft_tree_merge_1dcomm_&
(m,mb,n,blocksize,treeorder,v,ldv,t,ldt,baseidx,rev,mpicomm,work,lwork)
use precision
use qr_utils_mod
use elpa_blas_interfaces
implicit none
! input variables (local)
......@@ -395,13 +411,15 @@ subroutine qr_pdlarft_tree_merge_1dcomm_&
#ifdef DOUBLE_PRECISION_REAL
if (localsize .gt. 0) then
call dsyrk("Upper","Trans",n,localsize,1.0_rk8,v(baseoffset,1),ldv,0.0_rk8,work(1,1),n)
call dsyrk("Upper", "Trans", int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk8, work(1,1), int(n,kind=BLAS_KIND))
else
work(1:n,1:n) = 0.0_rk8
end if
#else
if (localsize .gt. 0) then
call ssyrk("Upper","Trans",n,localsize,1.0_rk4,v(baseoffset,1),ldv,0.0_rk4,work(1,1),n)
! Single-precision branch: the scalars are rk4, so the rank-k update must use
! ssyrk (not dsyrk). It fills the upper triangle of the n-by-n block
! work(1:n,1:n) with v(baseoffset:,1:n)^T * v(baseoffset:,1:n) over localsize rows.
call ssyrk("Upper", "Trans", int(n,kind=BLAS_KIND), int(localsize,kind=BLAS_KIND), &
1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), 0.0_rk4, work(1,1), int(n,kind=BLAS_KIND))
else
! localsize is not positive: zero the leading n-by-n part of work instead
work(1:n,1:n) = 0.0_rk4
end if
......@@ -437,7 +455,7 @@ subroutine qr_pdlarfl_1dcomm_&
use precision
use elpa1_impl
use qr_utils_mod
use elpa_blas_interfaces
implicit none
! input variables (local)
......@@ -529,7 +547,7 @@ subroutine qr_pdlarfl2_tmatrix_1dcomm_&
use precision
use elpa1_impl
use qr_utils_mod
use elpa_blas_interfaces
implicit none
! input variables (local)
......@@ -596,15 +614,21 @@ subroutine qr_pdlarfl2_tmatrix_1dcomm_&
#ifdef DOUBLE_PRECISION_REAL
work(1:sendsize) = 0.0_rk8
call dgemv("Trans",local_size1,n,1.0_rk8,a(local_offset1,1),lda,v(v1_local_offset,v1col),1,0.0_rk8,work(dgemv1_offset),1)
call dgemv("Trans",local_size2,n,t(v2col,v2col),a(local_offset2,1),lda,v(v2_local_offset,v2col),1,0.0_rk8, &
work(dgemv2_offset),1)
call dgemv("Trans", int(local_size1,kind=BLAS_KIND), int(n,kind=BLAS_KIND), 1.0_rk8, a(local_offset1,1), &
int(lda,kind=BLAS_KIND), v(v1_local_offset,v1col), 1_BLAS_KIND, 0.0_rk8, work(dgemv1_offset), &
1_BLAS_KIND)
call dgemv("Trans",int(local_size2,kind=BLAS_KIND), int(n,kind=BLAS_KIND), t(v2col,v2col), a(local_offset2,1), &
int(lda,kind=BLAS_KIND), v(v2_local_offset,v2col), 1_BLAS_KIND,0.0_rk8, &
work(dgemv2_offset),1_BLAS_KIND)
#else
work(1:sendsize) = 0.0_rk4
call sgemv("Trans",local_size1,n,1.0_rk4,a(local_offset1,1),lda,v(v1_local_offset,v1col),1,0.0_rk4,work(dgemv1_offset),1)
call sgemv("Trans",local_size2,n,t(v2col,v2col),a(local_offset2,1),lda,v(v2_local_offset,v2col),1,0.0_rk4, &
work(dgemv2_offset),1)
call sgemv("Trans", int(local_size1,kind=BLAS_KIND), int(n,kind=BLAS_KIND), 1.0_rk4, a(local_offset1,1), &
int(lda,kind=BLAS_KIND), v(v1_local_offset,v1col), 1_BLAS_KIND, 0.0_rk4, work(dgemv1_offset), &
1_BLAS_KIND)
call sgemv("Trans",int(local_size2,kind=BLAS_KIND), int(n,kind=BLAS_KIND), t(v2col,v2col), a(local_offset2,1), &
int(lda,kind=BLAS_KIND), v(v2_local_offset,v2col), 1_BLAS_KIND,0.0_rk4, &
work(dgemv2_offset),1_BLAS_KIND)
#endif
#ifdef WITH_MPI
......@@ -621,9 +645,11 @@ subroutine qr_pdlarfl2_tmatrix_1dcomm_&
#endif
! update second Vector
#ifdef DOUBLE_PRECISION_REAL
call daxpy(n,t(1,2),work(sendsize+dgemv1_offset),1,work(sendsize+dgemv2_offset),1)
call daxpy(int(n,kind=BLAS_KIND), t(1,2), work(sendsize+dgemv1_offset), 1_BLAS_KIND, &
work(sendsize+dgemv2_offset),1_BLAS_KIND)
#else
call saxpy(n,t(1,2),work(sendsize+dgemv1_offset),1,work(sendsize+dgemv2_offset),1)
call saxpy(int(n,kind=BLAS_KIND), t(1,2), work(sendsize+dgemv1_offset), 1_BLAS_KIND, &
work(sendsize+dgemv2_offset),1_BLAS_KIND)
#endif
call local_size_offset_1d(m,mb,baseidx,idx-2,rev,mpirank,mpiprocs, &
......@@ -670,7 +696,7 @@ subroutine qr_tmerge_pdlarfb_1dcomm_&
(m,mb,n,oldk,k,v,ldv,t,ldt,a,lda,baseidx,rev,updatemode,mpicomm,work,lwork)
use precision
use qr_utils_mod
use elpa_blas_interfaces
implicit none
! input variables (local)
......@@ -746,17 +772,23 @@ subroutine qr_tmerge_pdlarfb_1dcomm_&
! calculate matrix matrix product of householder vectors and target matrix
if (updatemode .eq. ichar('I')) then
! Z' = (Y1,Y2)' * A
call dgemm("Trans","Notrans",k+oldk,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
work(sendoffset+updateoffset),updatelda)
call dgemm("Trans", "Notrans", int(k+oldk,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND), 0.0_rk8, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
else
! Z' = Y1' * A
call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
work(sendoffset+updateoffset),updatelda)
call dgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND), 0.0_rk8, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
end if
! calculate parts needed for T merge
call dgemm("Trans","Notrans",k,oldk,localsize,1.0_rk8,v(baseoffset,1),ldv,v(baseoffset,k+1),ldv,0.0_rk8, &
work(sendoffset+mergeoffset),mergelda)
call dgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(oldk,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
v(baseoffset,k+1), int(ldv,kind=BLAS_KIND), 0.0_rk8, &
work(sendoffset+mergeoffset), int(mergelda,kind=BLAS_KIND))
else
! cleanup buffer
......@@ -767,17 +799,23 @@ subroutine qr_tmerge_pdlarfb_1dcomm_&
! calculate matrix matrix product of householder vectors and target matrix
if (updatemode .eq. ichar('I')) then
! Z' = (Y1,Y2)' * A
call sgemm("Trans","Notrans",k+oldk,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
work(sendoffset+updateoffset),updatelda)
call sgemm("Trans", "Notrans", int(k+oldk,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND), 0.0_rk4, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
else
! Z' = Y1' * A
call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
work(sendoffset+updateoffset),updatelda)
call sgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND), 0.0_rk4, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
end if
! calculate parts needed for T merge
call sgemm("Trans","Notrans",k,oldk,localsize,1.0_rk4,v(baseoffset,1),ldv,v(baseoffset,k+1),ldv,0.0_rk4, &
work(sendoffset+mergeoffset),mergelda)
call sgemm("Trans", "Notrans", int(k,kind=BLAS_KIND), int(oldk,kind=BLAS_KIND), &
int(localsize,kind=BLAS_KIND), 1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
v(baseoffset,k+1), int(ldv,kind=BLAS_KIND), 0.0_rk4, &
work(sendoffset+mergeoffset), int(mergelda,kind=BLAS_KIND))
else
! cleanup buffer
......@@ -798,8 +836,10 @@ subroutine qr_tmerge_pdlarfb_1dcomm_&
if (localsize .gt. 0) then
! calculate matrix matrix product of householder vectors and target matrix
! Z' = (Y1)' * A
call dgemm("Trans","Notrans",k,n,localsize,1.0_rk8,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk8, &
work(sendoffset+updateoffset),updatelda)
call dgemm("Trans","Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize, kind=BLAS_KIND), 1.0_rk8, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND),0.0_rk8, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
else
! cleanup buffer
......@@ -809,8 +849,10 @@ subroutine qr_tmerge_pdlarfb_1dcomm_&
if (localsize .gt. 0) then
! calculate matrix matrix product of householder vectors and target matrix
! Z' = (Y1)' * A
call sgemm("Trans","Notrans",k,n,localsize,1.0_rk4,v(baseoffset,1),ldv,a(offset,1),lda,0.0_rk4, &
work(sendoffset+updateoffset),updatelda)
call sgemm("Trans","Notrans", int(k,kind=BLAS_KIND), int(n,kind=BLAS_KIND), &
int(localsize, kind=BLAS_KIND), 1.0_rk4, v(baseoffset,1), int(ldv,kind=BLAS_KIND), &
a(offset,1), int(lda,kind=BLAS_KIND),0.0_rk4, &
work(sendoffset+updateoffset), int(updatelda,kind=BLAS_KIND))
else
! cleanup buffer
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment