Commit 57426d58 authored by Pavel Kus's avatar Pavel Kus

compatible interface for double complex

parent 3cf77a0a
......@@ -188,6 +188,20 @@ extern "C" {
return val;
}
void cublasZgemv_elpa_wrapper (char trans, int m, int n, double complex alpha,
const double complex *A, int lda, const double complex *x, int incx,
double complex beta, double complex *y, int incy) {
cuDoubleComplex alpha_casted = *((cuDoubleComplex*)(&alpha));
cuDoubleComplex beta_casted = *((cuDoubleComplex*)(&beta));
const cuDoubleComplex* A_casted = (const cuDoubleComplex*) A;
const cuDoubleComplex* x_casted = (const cuDoubleComplex*) x;
cuDoubleComplex* y_casted = (cuDoubleComplex*) y;
cublasZgemv(trans, m, n, alpha_casted, A_casted, lda, x_casted, incx, beta_casted, y_casted, incy);
}
void cublasCgemv_elpa_wrapper (char trans, int m, int n, float complex alpha,
const float complex *A, int lda, const float complex *x, int incx,
float complex beta, float complex *y, int incy) {
......@@ -202,6 +216,21 @@ extern "C" {
cublasCgemv(trans, m, n, alpha_casted, A_casted, lda, x_casted, incx, beta_casted, y_casted, incy);
}
void cublasZgemm_elpa_wrapper (char transa, char transb, int m, int n, int k,
double complex alpha, const double complex *A, int lda,
const double complex *B, int ldb, double complex beta,
double complex *C, int ldc) {
cuDoubleComplex alpha_casted = *((cuDoubleComplex*)(&alpha));
cuDoubleComplex beta_casted = *((cuDoubleComplex*)(&beta));
const cuDoubleComplex* A_casted = (const cuDoubleComplex*) A;
const cuDoubleComplex* B_casted = (const cuDoubleComplex*) B;
cuDoubleComplex* C_casted = (cuDoubleComplex*) C;
cublasZgemm(transa, transb, m, n, k, alpha_casted, A_casted, lda, B_casted, ldb, beta_casted, C_casted, ldc);
}
void cublasCgemm_elpa_wrapper (char transa, char transb, int m, int n, int k,
float complex alpha, const float complex *A, int lda,
const float complex *B, int ldb, float complex beta,
......@@ -217,6 +246,18 @@ extern "C" {
cublasCgemm(transa, transb, m, n, k, alpha_casted, A_casted, lda, B_casted, ldb, beta_casted, C_casted, ldc);
}
void cublasZtrmm_elpa_wrapper (char side, char uplo, char transa, char diag,
int m, int n, double complex alpha, const double complex *A,
int lda, double complex *B, int ldb){
cuDoubleComplex alpha_casted = *((cuDoubleComplex*)(&alpha));
const cuDoubleComplex* A_casted = (const cuDoubleComplex*) A;
cuDoubleComplex* B_casted = (cuDoubleComplex*) B;
cublasZtrmm(side, uplo, transa, diag, m, n, alpha_casted, A_casted, lda, B_casted, ldb);
}
void cublasCtrmm_elpa_wrapper (char side, char uplo, char transa, char diag,
int m, int n, float complex alpha, const float complex *A,
int lda, float complex *B, int ldb){
......
......@@ -295,7 +295,7 @@ module cuda_functions
end interface
interface
subroutine cublas_zgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) bind(C,name='cublasZgemm')
subroutine cublas_zgemm_c(cta, ctb, m, n, k, alpha, a, lda, b, ldb, beta, c,ldc) bind(C,name='cublasZgemm_elpa_wrapper')
use iso_c_binding
......@@ -325,7 +325,7 @@ module cuda_functions
end interface
interface
subroutine cublas_ztrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasZtrmm')
subroutine cublas_ztrmm_c(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb) bind(C,name='cublasZtrmm_elpa_wrapper')
use iso_c_binding
......@@ -383,7 +383,7 @@ module cuda_functions
end interface
interface
subroutine cublas_zgemv_c(cta, m, n, alpha, a, lda, x, incx, beta, y, incy) bind(C,name='cublasZgemv')
subroutine cublas_zgemv_c(cta, m, n, alpha, a, lda, x, incx, beta, y, incy) bind(C,name='cublasZgemv_elpa_wrapper')
use iso_c_binding
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment