#include <mpi.h>
//#include <complex>
#include "mkl.h"

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <time.h>


// Select the base precision: define exactly one of BASE_PREC_FLOAT / BASE_PREC_DOUBLE
//#define BASE_PREC_FLOAT

#define BASE_PREC_DOUBLE

#if defined (BASE_PREC_FLOAT)
#define T_base_precision float
#elif defined (BASE_PREC_DOUBLE)
#define T_base_precision double
#else
#error "Define exactly one of BASE_PREC_FLOAT / BASE_PREC_DOUBLE"
#endif

#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define MIN(a,b) ((a) < (b) ? (a) : (b))
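
// Note: as function-style macros, MAX and MIN evaluate their arguments twice; avoid side effects in a and b.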

#if defined(USE_CUDA)
#include <cuda_runtime.h>

#define gpuErrCheck(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line)
{
   if (code != cudaSuccess) 
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      exit(code);
   }
}
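
// Usage: wrap any call that returns cudaError_t, e.g. gpuErrCheck( cudaMalloc(&d_ptr, bytes) );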
#endif
//______________________________________________

int debug_mode=0;

int N; //problem size
int nev; // size of the nev x nev warmup problem (set in LoadCblacs)

enum matrix_types {Clement, Random_0_1, Random_m1_1, Sequence}; // supported test matrices; Clement is the symmetric Clement (Kac) matrix
enum matrix_types matrix_type = Random_0_1;


/************   MPI  ***************************/

int world_rank, world_size;
int world_size_external; // for partial usage of the node

/************  BLACS ***************************/

int NB=256; // Global column block size for partitioning the global matrix
int ictxt; // BLACS context handle
int nprow; // Number of total process rows    in the process grid (equivalent to Pr)
int npcol; // Number of total process columns in the process grid (equivalent to Pc)

int myrow; // The calling process's row     coordinate in the process grid
int mycol; // The calling process's column coordinate in the process grid.
int info=0; // output status of driver and computational routines: 0 means success; for ScaLAPACK, info<0 flags an illegal argument

int m_loc; // size of local matrices is m_loc x n_loc for NxN global matrix; size of local vectors is m_loc x 1
int n_loc;

int m_loc_reduced; // size of local matrices is m_loc_reduced x n_loc_reduced for nev x nev global matrix
int n_loc_reduced;

int desc_N_N[9], desc_N_nev[9], desc_nev_nev[9];
#ifdef __cplusplus
extern "C" 
	{
#endif
	// Cblacs declarations
	void Cblacs_pinfo(int*, int*);
	void Cblacs_get(int, int, int*);
	void Cblacs_gridinit(int*, const char*, int, int);
	void Cblacs_pcoord(int, int, int*, int*);
	void Cblacs_gridexit(int);
	void Cblacs_barrier(int, const char*);
	int numroc_(int*, int*, int*, int*, int*);
	void Cblacs_gridinfo(int, int*, int*, int*, int*);
	void descinit_ (MKL_INT *desc, const MKL_INT *m, const MKL_INT *n, const MKL_INT *mb, const MKL_INT *nb, const MKL_INT *irsrc, const MKL_INT *icsrc, const MKL_INT *ictxt, const MKL_INT *lld, MKL_INT *info);
	
   // P BLAS declarations
   void pdgemm_ ( const char *transa , const char *transb , const MKL_INT *m , const MKL_INT *n , const MKL_INT *k , const double *alpha , const double *a , const MKL_INT *ia , const MKL_INT *ja , const MKL_INT *desca , const double *b , const MKL_INT *ib , const MKL_INT *jb , const MKL_INT *descb , const double *beta , double *c , const MKL_INT *ic , const MKL_INT *jc , const MKL_INT *descc );
   void pdscal_ ( const MKL_INT *n , const double *a , double *x , const MKL_INT *ix , const MKL_INT *jx , const MKL_INT *descx , const MKL_INT *incx );
   void pdaxpy_ ( const MKL_INT *n , const double *a , const double *x , const MKL_INT *ix , const MKL_INT *jx , const MKL_INT *descx , const MKL_INT *incx , double *y , const MKL_INT *iy , const MKL_INT *jy , const MKL_INT *descy , const MKL_INT *incy );
   void pdnrm2_ ( const MKL_INT *n , double *norm2 , const double *x , const MKL_INT *ix , const MKL_INT *jx , const MKL_INT *descx , const MKL_INT *incx );
   void psgeadd_ ( const char *trans , const MKL_INT *m , const MKL_INT *n , const float *alpha , const float *a , const MKL_INT *ia , const MKL_INT *ja , const MKL_INT *desca , const float *beta , float *c , const MKL_INT *ic , const MKL_INT *jc , const MKL_INT *descc );
   void pdgeadd_ ( const char *trans , const MKL_INT *m , const MKL_INT *n , const double *alpha , const double *a , const MKL_INT *ia , const MKL_INT *ja , const MKL_INT *desca , const double *beta , double *c , const MKL_INT *ic , const MKL_INT *jc , const MKL_INT *descc );
   
   //void cblas_dgemm ( const CBLAS_LAYOUT Layout , const CBLAS_TRANSPOSE transa , const CBLAS_TRANSPOSE transb , const MKL_INT m , const MKL_INT n , const MKL_INT k , const double alpha , const double *a , const MKL_INT lda , const double *b , const MKL_INT ldb , const double beta , double *c , const MKL_INT ldc );

   //void RefSyEv(T_high_precision *A, T_high_precision *X, T_high_precision *lambda, int N, int nev, int m_loc, int n_loc, int m_loc_reduced, int n_loc_reduced, int *desc_N_N, int *desc_N_nev, int *desc_nev_nev, int myrow, int mycol, int nprow, int npcol, int world_rank, int NB, int debug_mode);
#ifdef __cplusplus
	}
#endif

void createCublasHandle();

//____________________________________________________________________________________________

int LoadCblacs(int N) // returns 0 for successful and 1 for unsuccessful initialization
	{
	/*
   if (world_size== 2) {nprow = 2; npcol = 1;}
   if (world_size==12) {nprow = 3; npcol = 4;}
   if (world_size==24) {nprow = 4; npcol = 6;}
   if (world_size==36) {nprow = 6; npcol = 6;}
   */
   if      (world_size_external== 2) {nprow = 1; npcol = 2;}
   else if (world_size_external== 8) {nprow = 2; npcol = 4;}
   else if (world_size_external==72) {nprow = 8; npcol = 9;}
   else
      {
      nprow = (int) sqrt((double) world_size_external); // Number of process rows    in the process grid (equivalent to Pr)
      if (nprow>1 && NB*nprow>N) nprow = MAX(1, N/NB);  // keep the grid no larger than the blocked matrix
      npcol = nprow; // Number of process columns in the process grid (equivalent to Pc)
      }
      
	if (nprow==1 && npcol==1) NB=N;
	if (world_rank==0) printf("N= %d, nprow=%d, npcol=%d, NB=%d \n", N, nprow, npcol, NB);

	Cblacs_pinfo( &world_rank, &world_size ) ; // Routine is used when some initial system information is required before the BLACS are set up
	Cblacs_get( -1, 0, &ictxt );
	Cblacs_gridinit( &ictxt, "Row", nprow, npcol );
	Cblacs_gridinfo( ictxt, &nprow, &npcol, &myrow, &mycol );

	if (!((myrow>-1) && (mycol>-1) && (myrow<nprow) && (mycol<npcol))) return 1; // this process is not part of the grid

	// NUMROC computes the NUMber of Rows Or Columns of a distributed (small local) matrix owned by the process indicated by myrow,mycol
	int iZERO=0;
	m_loc = numroc_( &N, &NB, &myrow, &iZERO, &nprow ); // size of local matrix is m_loc x n_loc
	n_loc = numroc_( &N, &NB, &mycol, &iZERO, &npcol );
   
   //Cblacs_barrier(ictxt, "All");

	// DESCINIT initializes the array descriptor with the 8 input arguments: M, N, MB, NB, IRSRC, ICSRC, ICTXT, LLD
	descinit_(desc_N_N,   &N, &N  ,    &NB, &NB,   &iZERO, &iZERO, &ictxt, &m_loc, &info); // m_loc is the local leading dimension
	//descinit_(desc_N_nev, &N, &nev,    &NB, &NB,   &iZERO, &iZERO, &ictxt, &m_loc, &info);
	nev = MIN(N, 10240);
	descinit_(desc_nev_nev, &nev, &nev,    &NB, &NB,   &iZERO, &iZERO, &ictxt, &m_loc, &info); // LLD stays m_loc because the nev x nev matrix reuses the N x N local buffers; LLD only has to be >= the local row count
	//descinit_(descVector,  &N, &iONE, &NB, &iONE, &iZERO, &iZERO, &ictxt, &m_loc, &info);
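
	// A hedged sketch: m_loc_reduced and n_loc_reduced are declared above but never assigned.
	// If tight local sizes for the nev x nev problem are needed, they would presumably come from numroc_ as well:
	m_loc_reduced = numroc_( &nev, &NB, &myrow, &iZERO, &nprow );
	n_loc_reduced = numroc_( &nev, &NB, &mycol, &iZERO, &npcol );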

	// print parameters of the processes
	if (debug_mode==1)
      {
      int iter_rank;
      for(iter_rank=0; iter_rank<world_size; iter_rank++)
         {
         if (iter_rank==world_rank) printf("load_cblacs.h: Proc-%d, m_loc= %d, n_loc=%d \n", world_rank, m_loc, n_loc);
         //Cblacs_barrier(ictxt, "All");
         }
      }
   else if (world_rank==0) printf("load_cblacs.h: Proc-%d, m_loc= %d, n_loc=%d \n", world_rank, m_loc, n_loc);

   return 0;
	}


//___________________________________________________________________________

void PrintMatrix(T_base_precision *Mat, int n_rows, int n_cols)
	{
	int i,j;
	for(i=0; i<n_rows; i++) 
		{
		for(j=0; j<n_cols; j++)
			{
			printf("%.16g\t", Mat[i+j*n_rows]);
			}
		printf("\n");
		}
	}
   
//___________________________________________________________________________

T_base_precision function_H (int I_gl, int J_gl)
   {
   T_base_precision func_H=0;
   
   if (matrix_type==Clement)
      {
      if (I_gl==J_gl+1 || J_gl==I_gl+1) 
         {
         int K = MIN(I_gl, J_gl);
         func_H = (T_base_precision) sqrtl((K+1) * (N - K-1));
         }
      }
   else if (matrix_type==Sequence)
      {
      if (J_gl>I_gl) return function_H (J_gl, I_gl); // symmetrize: take the value from the lower triangle
      return I_gl + N*J_gl + 1;
      }
   else if (matrix_type==Random_0_1)
      {
      func_H = (double)(rand())/RAND_MAX;
      //if (abs(I_gl-J_gl)!=1) func_H = double(rand())/RAND_MAX;
      //if (!(I_gl==0&&J_gl==1 || I_gl==1&&J_gl==0)) func_H = double(rand())/RAND_MAX;
      }
   else if (matrix_type==Random_m1_1)
      {
      func_H = (2.0*rand())/RAND_MAX - 1.0;
      }
   else
      {
      printf("matrix_type=%d is not supported", matrix_type);
      exit(1);
      }
   
   return func_H;
   }
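
// Validation note: the symmetric Clement (Kac) matrix of order N has the known integer
// eigenvalues -(N-1), -(N-3), ..., N-3, N-1, which makes it a convenient eigensolver test case.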

//___________________________________________________________________________


int main(int argc, char** argv)
{

/************  MPI ***************************/
MPI_Init( &argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
if (world_rank==0) printf("world_size=%d \n", world_size);
world_size_external = world_size;

/************  --- ***************************/

if (argc==1) // no extra arguments: use the default N
   {
   N = 10;
   }
   
if (argc==2) // one extra argument: N
   {
   N = atoi(argv[1]);
   }

if (argc==3) // two extra arguments: N, NB
   {
   N = atoi(argv[1]);
   NB = atoi(argv[2]);
   }

if (argc==4) // three extra arguments: N, NB, world_size_external
   {
   N = atoi(argv[1]);
   NB = atoi(argv[2]);
   world_size_external = atoi(argv[3]);
   }
   
if (world_rank==0) printf("world_size_external=%d \n", world_size_external);   
if (N>=100) debug_mode=0; // matrix printouts are only meaningful for small problems
//____________________________________________ 
// write output into file for debugging
//std::string OutputName = "output-"+std::to_string(world_rank)+".txt"; 
//if (debug_mode==1 && world_size>1) freopen(OutputName.c_str(), "w", stdout);
//____________________________________________    

clock_t t0, t1;
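// Note: clock() measures per-process CPU time, not wall-clock time, so with threaded MKL the
// reported times can exceed the elapsed time. A common alternative (not used here) is
// MPI_Barrier(MPI_COMM_WORLD); double w0 = MPI_Wtime(); ...; double dt = MPI_Wtime() - w0;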
//double t_gemm;
   
//std::mt19937 gen(1337.0);
//std::mt19937 gen(1.0);
//std::normal_distribution<> d;
//std::uniform_real_distribution<> uniform_distr(-1.0, 1.0);
//std::uniform_real_distribution<> uniform_distr(0, 1.0);
	
//srand(time(NULL));	
srand(1+world_rank); // seeding
//double r = (double)(rand())/RAND_MAX; // returns a pseudo-random real number in [0, 1]
	
/*parameters of block-cyclic data layout*/
int irsrc = 0; 
int icsrc = 0;

  
int dims[2];
dims[0] = dims[1] = 0;
// MPI proc grid = dims[0] x dims[1] (currently unused; the BLACS grid is set up in LoadCblacs instead)
MPI_Dims_create(world_size, 2, dims);

	
//----------------------------------------------------

// 0. Setup blacs
LoadCblacs(N);
if (!((myrow>-1) && (mycol>-1) && (myrow<nprow) && (mycol<npcol))) // checks whether the process is in the grid
	{
	printf("bye-bye from process-%d \n", world_rank); 
	MPI_Finalize();
	return 0;
	}

#if defined (USE_CUDA)
   createCublasHandle();
#endif

// 1. Fill the local matrix
t0 = clock();
//vector<T_base_precision> H_loc = vector<T_base_precision>(m_loc*n_loc, 0);
T_base_precision *A_loc, *B_loc, *C_loc;
A_loc  = (T_base_precision *) calloc(m_loc*n_loc, sizeof(T_base_precision));
B_loc  = (T_base_precision *) calloc(m_loc*n_loc, sizeof(T_base_precision));
C_loc  = (T_base_precision *) calloc(m_loc*n_loc, sizeof(T_base_precision));
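
// A minimal allocation guard (calloc returns NULL on failure); aborts all ranks rather than crashing later
if (A_loc==NULL || B_loc==NULL || C_loc==NULL)
   {
   fprintf(stderr, "Proc-%d: failed to allocate local %d x %d matrices\n", world_rank, m_loc, n_loc);
   MPI_Abort(MPI_COMM_WORLD, 1);
   }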

int i_loc, j_loc;	
for(i_loc=0; i_loc<m_loc; i_loc++) // iteration over the first state
	{
	int l_1 = i_loc/NB; // local coord of the (NBxNB) block among other blocks
	int x_1 = i_loc%NB; // local coord within the block
	int I_gl= (l_1*nprow + myrow)*NB + x_1; // gl-global; nprow = "P_r"; myrow="p_r"  (quoted-ScaLAPACK userguide notation, p.88-90)
		
	for(j_loc=0; j_loc<n_loc; j_loc++) // iteration over the second state
		{
      int l_2 = j_loc/NB; // local coord of the (NBxNB) block among other blocks
      int x_2 = j_loc%NB; // local coord within the block
      int J_gl= (l_2*npcol + mycol)*NB + x_2;
    
		//H_loc[i_loc+j_loc*m_loc] = H[I_gl+J_gl*N];
		A_loc[i_loc+j_loc*m_loc] = function_H(I_gl, J_gl); // for the random matrix types, A and B get different draws from rand()
		B_loc[i_loc+j_loc*m_loc] = function_H(I_gl, J_gl);
		}
	}
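
// Inverse mapping, for reference: global row I lives in process row (I/NB) % nprow,
// at local index (I/(NB*nprow))*NB + I%NB; columns are analogous with npcol and mycol.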

t1 = clock();
//t_gemm=(double)(t1-t0)/CLOCKS_PER_SEC;
if (world_rank==0 || debug_mode==1) printf("Matrix creation time: %f sec.\n", (double)(t1-t0)/CLOCKS_PER_SEC);
   
#if defined(PIN_MATRICES) && defined(USE_CUDA) // pinning uses CUDA calls, so require USE_CUDA as well
   gpuErrCheck( cudaHostRegister(A_loc, m_loc*n_loc*sizeof(T_base_precision), cudaHostRegisterDefault) );
   gpuErrCheck( cudaHostRegister(B_loc, m_loc*n_loc*sizeof(T_base_precision), cudaHostRegisterDefault) );
   gpuErrCheck( cudaHostRegister(C_loc, m_loc*n_loc*sizeof(T_base_precision), cudaHostRegisterDefault) );
#endif
   
if (debug_mode==1) 
   {
   printf("A_loc: \n");
   PrintMatrix(A_loc, m_loc, n_loc);
   
   printf("B_loc: \n");
   PrintMatrix(B_loc, m_loc, n_loc);
   }

   
// 2. Call ScaLAPACK routines (perform the distributed GEMM benchmarks)

T_base_precision real_minus_ONE = (T_base_precision)(-1.0);
T_base_precision real_ONE = (T_base_precision)( 1.0);
T_base_precision real_ZERO = (T_base_precision)(0.0);
int iONE = 1; 

t0 = clock();
#if defined (BASE_PREC_DOUBLE)
   pdgemm_ ("N", "N", &nev  , &nev, &nev, &real_ONE, A_loc, &iONE, &iONE, desc_nev_nev, B_loc , &iONE, &iONE, desc_nev_nev, &real_ZERO, C_loc, &iONE, &iONE, desc_nev_nev); // C = A*B
#endif
t1 = clock();
double t_gemm_warmup=(double)(t1-t0)/CLOCKS_PER_SEC;
if (world_rank==0 || debug_mode==1) printf("Warmup GEMM NN time: %f sec.\n", t_gemm_warmup);

t0 = clock();
#if defined (BASE_PREC_DOUBLE)
   pdgemm_ ("T", "N", &N  , &N, &N, &real_ONE, A_loc, &iONE, &iONE, desc_N_N, B_loc , &iONE, &iONE, desc_N_N, &real_ZERO, C_loc, &iONE, &iONE, desc_N_N); // C = A*B
#endif
t1 = clock();
double t_gemm_tn=(double)(t1-t0)/CLOCKS_PER_SEC;
if (world_rank==0 || debug_mode==1) printf("GEMM TN time: %f sec.\n", t_gemm_tn);

   
t0 = clock();
#if defined (BASE_PREC_DOUBLE)
   pdgemm_ ("N", "N", &N  , &N, &N, &real_ONE, A_loc, &iONE, &iONE, desc_N_N, B_loc , &iONE, &iONE, desc_N_N, &real_ZERO, C_loc, &iONE, &iONE, desc_N_N); // C = A*B
#endif
t1 = clock();
double t_gemm_nn=(double)(t1-t0)/CLOCKS_PER_SEC;
if (world_rank==0 || debug_mode==1) printf("GEMM NN time: %f sec.\n", t_gemm_nn);
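
// A rough throughput estimate, assuming the standard 2*N^3 flop count for GEMM; since clock()
// measures CPU time (see the note above), treat these numbers as indicative only.
if (world_rank==0 && t_gemm_nn>0.0 && t_gemm_tn>0.0)
   {
   double gemm_flops = 2.0*(double)N*(double)N*(double)N;
   printf("Approx. rates: GEMM NN %f GFLOP/s, GEMM TN %f GFLOP/s\n", gemm_flops/t_gemm_nn*1e-9, gemm_flops/t_gemm_tn*1e-9);
   }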
   

 

// 3. Finalize

if (debug_mode==1) 
   {
   printf("C:\n");
   PrintMatrix(C_loc, N, N);
   }
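
// A minimal sanity check using the pdnrm2_ declared above: the 2-norm of the first column of the
// distributed C. Per the PBLAS convention (as we understand it), the scalar result is valid in the
// process column owning the vector; column 1 lives in process column 0, which contains rank 0.
#if defined (BASE_PREC_DOUBLE)
   {
   double c_col_norm = 0.0;
   pdnrm2_ (&N, &c_col_norm, C_loc, &iONE, &iONE, desc_N_N, &iONE);
   if (world_rank==0) printf("Sanity check: ||C(:,1)||_2 = %g\n", c_col_norm);
   }
#endif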

/*   
if (world_rank==0 || debug_mode==1) 
   {
   printf("\n N, t_elpa_float, t_elpa_double, t_refinement(s)\n");
   printf("%d,\t%f,\t%f", N, t_elpa_float, t_elpa_double);
   for(int iter=0; iter<n_iter_ref; iter++) printf(",\t%f", t_refinement[iter]);
   printf("\n\n Done! \n");
   
   // write to file
   std::ofstream myfile;
   myfile.open ("s_data.txt", std::ios_base::app | std::ios_base::out); // add to the current file (not rewrite)
   myfile << std::fixed;
   myfile << N << ",\t\t" << t_elpa_float << ",\t" << t_elpa_double;
   for(int iter=0; iter<n_iter_ref; iter++) myfile << ",\t" << t_refinement[iter];
   myfile << std::endl;
   myfile.close();
   }
*/

#if defined(PIN_MATRICES) && defined(USE_CUDA)
   gpuErrCheck( cudaHostUnregister(A_loc) );
   gpuErrCheck( cudaHostUnregister(B_loc) );
   gpuErrCheck( cudaHostUnregister(C_loc) );
   if (world_rank==0) printf("Pinned ScaLAPACK matrices A_loc, B_loc, C_loc on CPUs have been unregistered\n");
#endif

free(A_loc);
free(B_loc);
free(C_loc);

if (world_rank==0 || debug_mode==1) 
   {
   printf("\n N, world_size_external, t_gemm_nn, t_gemm_tn\n");
   printf("%d,\t\t%d,\t\t%f,\t%f\n", N, world_size_external, t_gemm_nn, t_gemm_tn);
   
   FILE *f = fopen("s_data.txt", "a"); // a=append
   if (f == NULL)
      {
      printf("Error opening file!\n");
      exit(1);
      }
   fprintf(f, "%d,\t\t%d,\t\t%f,\t%f\n", N, world_size_external, t_gemm_nn, t_gemm_tn);
   fclose(f); // was missing: without it the timing line may never be flushed to disk
   
   // write to file
   /*
   std::ofstream myfile;
   myfile.open ("s_data.txt", std::ios_base::app | std::ios_base::out); // add to the current file (not rewrite)
   myfile << std::fixed;
   myfile << N << ",\t\t" << t_gemm_nn << ",\t" << t_gemm_tn;
   myfile << std::endl;
   myfile.close();
   */
   }
   
if ((myrow>-1) && (mycol>-1) && (myrow<nprow) && (mycol<npcol)) Cblacs_gridexit(ictxt); // leave the grid if we were in one (pass the context handle, not 0)
  
MPI_Finalize();
return 0;
}