From e914db36b907d7b5adc90866b46d72409224d8d0 Mon Sep 17 00:00:00 2001
From: David Rohr <drohr@jwdt.org>
Date: Sun, 13 Apr 2014 18:12:52 +0200
Subject: [PATCH] implemented FFT algorithm on GPU, first version

---
 CMakeLists.txt                |   1 +
 bioem.cpp                     |  90 +++----------------
 bioem_algorithm.h             |  60 +++++++++++++
 bioem_cuda.cu                 | 163 +++++++++++++++++++++++++---------
 include/bioem.h               |   3 +-
 include/bioem_cuda_internal.h |   7 ++
 include/defs.h                |   9 +-
 include/map.h                 |  30 ++-----
 include/param.h               |   7 +-
 map.cpp                       |  22 ++---
 param.cpp                     |  33 ++++---
 11 files changed, 251 insertions(+), 174 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3077dd7..a97eed3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,6 +28,7 @@ INCLUDE_DIRECTORIES( include $HOME/usr/include )
 IF(CUDA_FOUND)
 	ADD_DEFINITIONS(-DWITH_CUDA)
 	CUDA_ADD_EXECUTABLE( bioEM bioem.cpp  main.cpp  map.cpp  model.cpp  param.cpp cmodules/timer.cpp bioem_cuda.cu )
+	CUDA_ADD_CUFFT_TO_TARGET(bioEM)
 ELSE()
 	ADD_EXECUTABLE( bioEM bioem.cpp  main.cpp  map.cpp  model.cpp  param.cpp cmodules/timer.cpp )
 ENDIF()
diff --git a/bioem.cpp b/bioem.cpp
index b42a35d..c5eeca6 100644
--- a/bioem.cpp
+++ b/bioem.cpp
@@ -263,7 +263,7 @@ int bioem::run()
 			/*** Calculating convolutions of projection map and crosscorrelations ***/
 
 			timer.ResetStart();
-			createConvolutedProjectionMap(iProjectionOut,iConv,proj_mapFFT,conv_map,conv_mapFFT,sumCONV,sumsquareCONV);
+			createConvolutedProjectionMap(iProjectionOut, iConv, proj_mapFFT, conv_map, conv_mapFFT, sumCONV, sumsquareCONV);
 			printf("Time Convolution %d %d: %f\n", iProjectionOut, iConv, timer.GetCurrentElapsedTime());
 
 			/***************************************************************************************/
@@ -348,10 +348,10 @@ int bioem::run()
 		param.refCTF =NULL;
 	}
 
-	if(RefMap.RefMapFFT)
+	if(RefMap.RefMapsFFT)
 	{
-		delete[] RefMap.RefMapFFT;
-		RefMap.RefMapFFT = NULL;
+		delete[] RefMap.RefMapsFFT;
+		RefMap.RefMapsFFT = NULL;
 	}
 	return(0);
 }
@@ -375,7 +375,7 @@ int bioem::compareRefMaps(int iProjectionOut, int iConv, const bioem_map& conv_m
 
 			for (int iRefMap = iStart; iRefMap < iEnd; iRefMap ++)
 			{
-				calculateCCFFT(iRefMap,iProjectionOut, iConv, sumC,sumsquareC, localmultFFT, localCCT,lCC);
+				calculateCCFFT(iRefMap,iProjectionOut, iConv, sumC, sumsquareC, localmultFFT, localCCT,lCC);
 			}
 			myfftw_free(localCCT);
 			myfftw_free(lCC);
@@ -392,82 +392,20 @@ int bioem::compareRefMaps(int iProjectionOut, int iConv, const bioem_map& conv_m
 	return(0);
 }
 
-/////////////NEW ROUTINE ////////////////
-inline int bioem::calculateCCFFT(int iRefMap, int iOrient, int iConv, myfloat_t sumC,myfloat_t sumsquareC, mycomplex_t* localConvFFT,mycomplex_t* localCCT,myfloat_t* lCC)
+inline void bioem::calculateCCFFT(int iRefMap, int iOrient, int iConv, myfloat_t sumC,myfloat_t sumsquareC, mycomplex_t* localConvFFT,mycomplex_t* localCCT,myfloat_t* lCC)
 {
+	const mycomplex_t* RefMapFFT = &RefMap.RefMapsFFT[iRefMap * param.RefMapSize];
 	for(int i = 0;i < param.param_device.NumberPixels * param.param_device.NumberFFTPixels1D;i++)
 	{
-		localCCT[i][0] = localConvFFT[i][0] * RefMap.RefMapFFT[iRefMap].cpoints[i][0] + localConvFFT[i][1] * RefMap.RefMapFFT[iRefMap].cpoints[i][1];
-		localCCT[i][1] = localConvFFT[i][1] * RefMap.RefMapFFT[iRefMap].cpoints[i][0] - localConvFFT[i][0] * RefMap.RefMapFFT[iRefMap].cpoints[i][1];
+		localCCT[i][0] = localConvFFT[i][0] * RefMapFFT[i][0] + localConvFFT[i][1] * RefMapFFT[i][1];
+		localCCT[i][1] = localConvFFT[i][1] * RefMapFFT[i][0] - localConvFFT[i][0] * RefMapFFT[i][1];
 	}
 
 	myfftw_execute_dft_c2r(param.fft_plan_c2r_backward,localCCT,lCC);
 
-// Storing CORRELATIONS FOR CORRESPONDING DISPLACEMENTS & Normalizing after Backward FFT
-	for (int cent_x = 0; cent_x <= param.param_device.maxDisplaceCenter; cent_x=cent_x+param.param_device.GridSpaceCenter)
-	{
-		for (int cent_y = 0; cent_y <= param.param_device.maxDisplaceCenter; cent_y=cent_y+param.param_device.GridSpaceCenter)
-		{
-			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.param_device.NumberPixels+cent_y]/ (myfloat_t) (param.param_device.NumberPixels * param.param_device.NumberPixels), cent_x, cent_y);
-		}
-		for (int cent_y = param.param_device.NumberPixels-param.param_device.maxDisplaceCenter; cent_y < param.param_device.NumberPixels; cent_y=cent_y+param.param_device.GridSpaceCenter)
-		{
-			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.param_device.NumberPixels+cent_y]/ (myfloat_t) (param.param_device.NumberPixels*param.param_device.NumberPixels), cent_x, param.param_device.NumberPixels-cent_y);
-		}
-	}
-	for (int cent_x = param.param_device.NumberPixels-param.param_device.maxDisplaceCenter; cent_x < param.param_device.NumberPixels; cent_x=cent_x+param.param_device.GridSpaceCenter)
-	{
-		for (int cent_y = 0; cent_y < param.param_device.maxDisplaceCenter; cent_y=cent_y+param.param_device.GridSpaceCenter)
-		{
-			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.param_device.NumberPixels+cent_y]/ (myfloat_t) (param.param_device.NumberPixels*param.param_device.NumberPixels), param.param_device.NumberPixels-cent_x, cent_y);
-		}
-		for (int cent_y = param.param_device.NumberPixels-param.param_device.maxDisplaceCenter; cent_y <= param.param_device.NumberPixels; cent_y=cent_y+param.param_device.GridSpaceCenter)
-		{
-			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.param_device.NumberPixels+cent_y]/ (myfloat_t) (param.param_device.NumberPixels*param.param_device.NumberPixels), param.param_device.NumberPixels-cent_x, param.param_device.NumberPixels-cent_y);
-		}
-	}
-
-	return (0);
+	doRefMapFFT(iRefMap, iOrient, iConv, lCC, sumC, sumsquareC, pProb, param.param_device, RefMap);
 }
 
-inline int bioem::calProb(int iRefMap,int iOrient, int iConv,myfloat_t sumC,myfloat_t sumsquareC, float value, int disx, int disy)
-{
-
-	/********************************************************/
-	/*********** Calculates the BioEM probability ***********/
-	/********************************************************/
-
-	const myfloat_t logpro = calc_logpro(param.param_device, sumC, sumsquareC, value, RefMap.sum_RefMap[iRefMap], RefMap.sumsquare_RefMap[iRefMap]);
-
-	//update_prob<-1>(logpro, iRefMap, iOrient, iConv, disx, disy, pProb);
-	//GCC is too stupid to inline properly, so the code is copied here
-    if(pProb[iRefMap].Constoadd < logpro)
-    {
-		pProb[iRefMap].Total = pProb[iRefMap].Total * exp(-logpro + pProb[iRefMap].Constoadd);
-		pProb[iRefMap].Constoadd = logpro;
-	}
-	pProb[iRefMap].Total += exp(logpro - pProb[iRefMap].Constoadd);
-
-	if(pProb[iRefMap].ConstAngle[iOrient] < logpro)
-	{
-		pProb[iRefMap].forAngles[iOrient] = pProb[iRefMap].forAngles[iOrient] * exp(-logpro + pProb[iRefMap].ConstAngle[iOrient]);
-		pProb[iRefMap].ConstAngle[iOrient] = logpro;
-	}
-	pProb[iRefMap].forAngles[iOrient] += exp(logpro - pProb[iRefMap].ConstAngle[iOrient]);
-
-	if(pProb[iRefMap].max_prob < logpro)
-	{
-		pProb[iRefMap].max_prob = logpro;
-		pProb[iRefMap].max_prob_cent_x = disx;
-		pProb[iRefMap].max_prob_cent_y = disy;
-		pProb[iRefMap].max_prob_orient = iOrient;
-		pProb[iRefMap].max_prob_conv = iConv;
-	}
-
-	return (0);
-}
-
-
 int bioem::createProjection(int iMap,mycomplex_t* mapFFT)
 {
 	/**************************************************************************************/
@@ -501,7 +439,6 @@ int bioem::createProjection(int iMap,mycomplex_t* mapFFT)
 	rotmat[2][1]=-sin(beta)*cos(alpha);
 	rotmat[2][2]=cos(beta);
 
-
 	for(int n=0; n< Model.nPointsModel; n++)
 	{
 		RotatedPointsModel[n].pos[0]=0.0;
@@ -571,11 +508,12 @@ int bioem::createConvolutedProjectionMap(int iMap,int iConv,mycomplex_t* lproj,b
 
 	/**** Multiplying FFTmap with corresponding kernel ****/
 
+	const mycomplex_t* refCTF = &param.refCTF[iConv * param.RefMapSize];
 	for(int i=0;i < param.param_device.NumberPixels * param.param_device.NumberFFTPixels1D;i++)
 	{
-		localmultFFT[i][0] = lproj[i][0] * param.refCTF[iConv].cpoints[i][0] + lproj[i][1] * param.refCTF[iConv].cpoints[i][1];
-		localmultFFT[i][1] = lproj[i][1] * param.refCTF[iConv].cpoints[i][0] - lproj[i][0] * param.refCTF[iConv].cpoints[i][1];
-		// cout << "GG " << i << " " << j << " " << param.refCTF[iConv].cpoints[i][0] << " " <<param.refCTF[iConv].cpoints[i][1] <<" " <<lproj[i][0] <<" " <<lproj[i][1] << "\n";
+		localmultFFT[i][0] = lproj[i][0] * refCTF[i][0] + lproj[i][1] * refCTF[i][1];
+		localmultFFT[i][1] = lproj[i][1] * refCTF[i][0] - lproj[i][0] * refCTF[i][1];
+		// cout << "GG " << i << " " << j << " " << refCTF[i][0] << " " << refCTF[i][1] <<" " <<lproj[i][0] <<" " <<lproj[i][1] << "\n";
 	}
 
 	//FFTW_C2R will destroy the input array, so we have to work on a copy here
diff --git a/bioem_algorithm.h b/bioem_algorithm.h
index e4f401b..5640cec 100644
--- a/bioem_algorithm.h
+++ b/bioem_algorithm.h
@@ -70,6 +70,66 @@ __device__ static inline myfloat_t calc_logpro(const bioem_param_device& param,
 	return(logpro);
 }
 
+__device__ static inline void calProb(int iRefMap,int iOrient, int iConv,myfloat_t sumC,myfloat_t sumsquareC, float value, int disx, int disy, bioem_Probability* pProb, const bioem_param_device& param, const bioem_RefMap& RefMap)
+{
+	/* Computes the log-probability of one sample with calc_logpro and folds it into */
+	/* pProb[iRefMap]: the overall sum, the per-orientation sum and the best sample. */
+	/* Note that value is passed as float while the other quantities are myfloat_t.  */
+
+	const myfloat_t logpro = calc_logpro(param, sumC, sumsquareC, value, RefMap.sum_RefMap[iRefMap], RefMap.sumsquare_RefMap[iRefMap]);
+
+	//update_prob<-1>(logpro, iRefMap, iOrient, iConv, disx, disy, pProb);
+	//The update is written out by hand here because GCC did not inline the call above properly
+    if(pProb[iRefMap].Constoadd < logpro)	//New largest exponent: rescale the sum accumulated so far to it
+    {
+		pProb[iRefMap].Total = pProb[iRefMap].Total * exp(-logpro + pProb[iRefMap].Constoadd);
+		pProb[iRefMap].Constoadd = logpro;
+	}
+	pProb[iRefMap].Total += exp(logpro - pProb[iRefMap].Constoadd);	//Add this sample relative to the current offset Constoadd
+
+	if(pProb[iRefMap].ConstAngle[iOrient] < logpro)	//Same rescaled accumulation, kept separately for each orientation
+	{
+		pProb[iRefMap].forAngles[iOrient] = pProb[iRefMap].forAngles[iOrient] * exp(-logpro + pProb[iRefMap].ConstAngle[iOrient]);
+		pProb[iRefMap].ConstAngle[iOrient] = logpro;
+	}
+	pProb[iRefMap].forAngles[iOrient] += exp(logpro - pProb[iRefMap].ConstAngle[iOrient]);
+
+	if(pProb[iRefMap].max_prob < logpro)	//Remember the sample with the largest log-probability seen so far
+	{
+		pProb[iRefMap].max_prob = logpro;
+		pProb[iRefMap].max_prob_cent_x = disx;
+		pProb[iRefMap].max_prob_cent_y = disy;
+		pProb[iRefMap].max_prob_orient = iOrient;
+		pProb[iRefMap].max_prob_conv = iConv;
+	}
+}
+
+__device__ static inline void doRefMapFFT(const int iRefMap, const int iOrient, const int iConv, const myfloat_t* lCC, const myfloat_t sumC, const myfloat_t sumsquareC, bioem_Probability* pProb, const bioem_param_device& param, const bioem_RefMap& RefMap)
+{	// Scores the sampled displacements of one reference map: lCC is the NumberPixels x NumberPixels result of the backward FFT, each sample is divided by NumberPixels^2 and handed to calProb.
+	for (int cent_x = 0; cent_x <= param.maxDisplaceCenter; cent_x=cent_x+param.GridSpaceCenter)	// rows 0 .. maxDisplaceCenter
+	{
+		for (int cent_y = 0; cent_y <= param.maxDisplaceCenter; cent_y=cent_y+param.GridSpaceCenter)
+		{
+			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.NumberPixels+cent_y]/ (myfloat_t) (param.NumberPixels * param.NumberPixels), cent_x, cent_y, pProb, param, RefMap);
+		}
+		for (int cent_y = param.NumberPixels-param.maxDisplaceCenter; cent_y < param.NumberPixels; cent_y=cent_y+param.GridSpaceCenter)	// columns at the end of the row, reported as NumberPixels-cent_y
+		{
+			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.NumberPixels+cent_y]/ (myfloat_t) (param.NumberPixels*param.NumberPixels), cent_x, param.NumberPixels-cent_y, pProb, param, RefMap);
+		}
+	}
+	for (int cent_x = param.NumberPixels-param.maxDisplaceCenter; cent_x < param.NumberPixels; cent_x=cent_x+param.GridSpaceCenter)	// rows at the end of the map, reported as NumberPixels-cent_x
+	{
+		for (int cent_y = 0; cent_y < param.maxDisplaceCenter; cent_y=cent_y+param.GridSpaceCenter)	// NOTE(review): '<' here but '<=' in the matching loop of the first row range, so cent_y == maxDisplaceCenter is skipped for these rows - confirm this is intended
+		{
+			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.NumberPixels+cent_y]/ (myfloat_t) (param.NumberPixels*param.NumberPixels), param.NumberPixels-cent_x, cent_y, pProb, param, RefMap);
+		}
+		for (int cent_y = param.NumberPixels-param.maxDisplaceCenter; cent_y < param.NumberPixels; cent_y=cent_y+param.GridSpaceCenter)	// strict bound: cent_y == NumberPixels would index the first element of the next row, or one past the end of lCC on the last row
+		{
+			calProb(iRefMap, iOrient, iConv, sumC, sumsquareC, (myfloat_t) lCC[cent_x*param.NumberPixels+cent_y]/ (myfloat_t) (param.NumberPixels*param.NumberPixels), param.NumberPixels-cent_x, param.NumberPixels-cent_y, pProb, param, RefMap);
+		}
+	}
+}
+
 template <int GPUAlgo, class RefT>
 __device__ static inline void compareRefMap(const int iRefMap, const int iOrient, const int iConv, const bioem_map& Mapconv, bioem_Probability* pProb, const bioem_param_device& param, const RefT& RefMap,
 	const int cent_x, const int cent_y, const int myShift = 0, const int nShifts2 = 0, const int myRef = 0, const bool threadActive = true)
diff --git a/bioem_cuda.cu b/bioem_cuda.cu
index d125710..223b7c6 100644
--- a/bioem_cuda.cu
+++ b/bioem_cuda.cu
@@ -12,13 +12,13 @@ using namespace std;
 
 //__constant__ bioem_map gConvMap;
 
-static inline void checkCudaErrors(cudaError_t error)
-{
-	if (error != cudaSuccess)
-	{
-		printf("CUDA Error %d / %s\n", error, cudaGetErrorString(error));
-		exit(1);
-	}
+#define checkCudaErrors(error) /* Aborts with file/line information when a CUDA runtime call does not return cudaSuccess */ \
+{ \
+	const cudaError_t cudaCheckResult = (error); /* evaluate the wrapped call exactly once, it must not be re-run while reporting */ \
+	if (cudaCheckResult != cudaSuccess) \
+	{ \
+		printf("CUDA Error %d / %s (%s: %d)\n", (int) cudaCheckResult, cudaGetErrorString(cudaCheckResult), __FILE__, __LINE__); \
+		exit(1); \
+	} \
 }
 
 bioem_cuda::bioem_cuda()
@@ -52,6 +52,15 @@ __global__ void compareRefMapShifted_kernel(const int iOrient, const int iConv,
 	}
 }
 
+__global__ void cudaZeroMem(void* ptr, size_t size)
+{	// Zero-fills the buffer in int-sized words; every thread starts at its global index and advances by the total number of launched threads.
+	const int nWords = size / sizeof(int);	// bytes beyond the last whole int are not written
+	const int stride = myBlockDimX * myGridDimX;
+	int* const words = (int*) ptr;
+	int idx = myBlockDimX * myBlockIdxX + myThreadIdxX;
+	while (idx < nWords) {words[idx] = 0; idx += stride;}
+}
+
 __global__ void compareRefMapLoopShifts_kernel(const int iOrient, const int iConv, const bioem_map* pMap, bioem_Probability* pProb, const bioem_param_device param, const bioem_RefMap* RefMap, const int blockoffset, const int nShifts, const int nShiftBits, const int maxRef)
 {
 	const size_t myid = (myBlockIdxX + blockoffset) * myBlockDimX + myThreadIdxX;
@@ -68,6 +77,26 @@ __global__ void compareRefMapLoopShifts_kernel(const int iOrient, const int iCon
 	compareRefMap<2>(iRefMap,iOrient,iConv,*pMap, pProb, param, *RefMap, cent_x, cent_y, myShift, nShifts * nShifts, myRef, threadActive);
 }
 
+__global__ void multComplexMap(const mycomplex_t* convmap, const mycomplex_t* refmap, mycuComplex_t* out, const int NumberPixelsTotal, const int MapSize, const int NumberMaps)
+{	// One block per reference map: out[map][i] = convmap[i] * conj(refmap[map][i]); the threads of the block stride over the pixels.
+	if (myBlockIdxX >= NumberMaps) return;
+	const mycomplex_t* ref = refmap + myBlockIdxX * MapSize;	// this block's reference map
+	mycuComplex_t* dst = out + myBlockIdxX * MapSize;	// this block's output map
+	for(int px = myThreadIdxX;px < NumberPixelsTotal;px += myBlockDimX)
+	{
+		dst[px].x = convmap[px][0] * ref[px][0] + convmap[px][1] * ref[px][1];	// real part
+		dst[px].y = convmap[px][1] * ref[px][0] - convmap[px][0] * ref[px][1];	// imaginary part
+	}
+}
+
+__global__ void cuDoRefMapsFFT(const int iOrient, const int iConv, const myfloat_t* lCC, const myfloat_t sumC, const myfloat_t sumsquareC, bioem_Probability* pProb, const bioem_param_device param, const bioem_RefMap* RefMap, const int maxRef)
+{	// One thread per reference map: each thread evaluates its own NumberPixels x NumberPixels slice of lCC with doRefMapFFT.
+	const int iRefMap = myBlockIdxX * myBlockDimX + myThreadIdxX;
+	if (iRefMap >= maxRef) return;	// threads beyond the last reference map have nothing to do
+	const int pixelsPerMap = param.NumberPixels * param.NumberPixels;
+	doRefMapFFT(iRefMap, iOrient, iConv, lCC + iRefMap * pixelsPerMap, sumC, sumsquareC, pProb, param, *RefMap);
+}
+
 template <class T> static inline T divup(T num, T divider) {return((num + divider - 1) / divider);}
 static inline bool IsPowerOf2(int x) {return ((x > 0) && ((x & (x - 1)) == 0));}
 #if defined(_WIN32)
@@ -93,53 +122,70 @@ int bioem_cuda::compareRefMaps(int iProjectionOut, int iConv, const bioem_map& c
 	{
 		checkCudaErrors(cudaEventSynchronize(cudaEvent[iConv & 1]));
 	}
-	checkCudaErrors(cudaMemcpyAsync(pConvMap_device[iConv & 1], &conv_map, sizeof(bioem_map), cudaMemcpyHostToDevice, cudaStream));
 
-	if (GPUAlgo == 2) //Loop over shifts
+	if (FFTAlgo)
 	{
-		const int nShifts = 2 * param.param_device.maxDisplaceCenter / param.param_device.GridSpaceCenter + 1;
-		if (!IsPowerOf2(nShifts))
+		checkCudaErrors(cudaMemcpyAsync(&pConvMapFFT[(iConv & 1) * param.RefMapSize], localmultFFT, param.RefMapSize * sizeof(mycomplex_t), cudaMemcpyHostToDevice, cudaStream));
+		multComplexMap<<<maxRef, CUDA_THREAD_COUNT, 0, cudaStream>>>(&pConvMapFFT[(iConv & 1) * param.RefMapSize], pRefMapsFFT, pFFTtmp2, param.param_device.NumberPixels * param.param_device.NumberFFTPixels1D, param.RefMapSize, maxRef);
+		cudaZeroMem<<<32, 256>>>(pFFTtmp, maxRef * sizeof(myfloat_t) * param.param_device.NumberPixels * param.param_device.NumberPixels);
+		if (mycufftExecC2R(plan, pFFTtmp2, pFFTtmp) != CUFFT_SUCCESS)
 		{
-			cout << "Invalid number of displacements, no power of two\n";
+			cout << "Error running CUFFT\n";
 			exit(1);
 		}
-		if (CUDA_THREAD_COUNT % (nShifts * nShifts))
+		cuDoRefMapsFFT<<<divup(maxRef, CUDA_THREAD_COUNT), CUDA_THREAD_COUNT, 0, cudaStream>>>(iProjectionOut, iConv, pFFTtmp, sumC, sumsquareC, pProb_device, param.param_device, pRefMap_device, maxRef);
+		checkCudaErrors(cudaGetLastError());
+	}
+	else
+	{
+		checkCudaErrors(cudaMemcpyAsync(pConvMap_device[iConv & 1], &conv_map, sizeof(bioem_map), cudaMemcpyHostToDevice, cudaStream));
+
+		if (GPUAlgo == 2) //Loop over shifts
 		{
-			cout << "CUDA Thread count (" << CUDA_THREAD_COUNT << ") is no multiple of number of shifts (" << (nShifts * nShifts) << ")\n";
-			exit(1);
+			const int nShifts = 2 * param.param_device.maxDisplaceCenter / param.param_device.GridSpaceCenter + 1;
+			if (!IsPowerOf2(nShifts))
+			{
+				cout << "Invalid number of displacements, no power of two\n";
+				exit(1);
+			}
+			if (CUDA_THREAD_COUNT % (nShifts * nShifts))
+			{
+				cout << "CUDA Thread count (" << CUDA_THREAD_COUNT << ") is no multiple of number of shifts (" << (nShifts * nShifts) << ")\n";
+				exit(1);
+			}
+			if (nShifts > CUDA_MAX_SHIFT_REDUCE)
+			{
+				cout << "Too many displacements for CUDA reduction\n";
+				exit(1);
+			}
+			const int nShiftBits = ilog2(nShifts);
+			size_t totalBlocks = divup((size_t) maxRef * (size_t) nShifts * (size_t) nShifts, (size_t) CUDA_THREAD_COUNT);
+			size_t nBlocks = CUDA_BLOCK_COUNT;
+			for (size_t i = 0;i < totalBlocks;i += nBlocks)
+			{
+				compareRefMapLoopShifts_kernel <<<min(nBlocks, totalBlocks - i), CUDA_THREAD_COUNT, (CUDA_THREAD_COUNT * 2 + CUDA_THREAD_COUNT / (nShifts * nShifts) * 4) * sizeof(myfloat_t), cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device, i, nShifts, nShiftBits, maxRef);
+			}
 		}
-		if (nShifts > CUDA_MAX_SHIFT_REDUCE)
+		else if (GPUAlgo == 1) //Split shifts in multiple kernels
 		{
-			cout << "Too many displacements for CUDA reduction\n";
-			exit(1);
+			for (int cent_x = -param.param_device.maxDisplaceCenter; cent_x <= param.param_device.maxDisplaceCenter; cent_x=cent_x+param.param_device.GridSpaceCenter)
+			{
+				for (int cent_y = -param.param_device.maxDisplaceCenter; cent_y <= param.param_device.maxDisplaceCenter; cent_y=cent_y+param.param_device.GridSpaceCenter)
+				{
+					compareRefMap_kernel <<<divup(maxRef, CUDA_THREAD_COUNT), CUDA_THREAD_COUNT, 0, cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device_Mod, cent_x, cent_y, maxRef);
+				}
+			}
 		}
-		const int nShiftBits = ilog2(nShifts);
-		size_t totalBlocks = divup((size_t) maxRef * (size_t) nShifts * (size_t) nShifts, (size_t) CUDA_THREAD_COUNT);
-		size_t nBlocks = CUDA_BLOCK_COUNT;
-		for (size_t i = 0;i < totalBlocks;i += nBlocks)
+		else if (GPUAlgo == 0) //All shifts in one kernel
 		{
-			compareRefMapLoopShifts_kernel <<<min(nBlocks, totalBlocks - i), CUDA_THREAD_COUNT, (CUDA_THREAD_COUNT * 2 + CUDA_THREAD_COUNT / (nShifts * nShifts) * 4) * sizeof(myfloat_t), cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device, i, nShifts, nShiftBits, maxRef);
+			compareRefMapShifted_kernel <<<divup(maxRef, CUDA_THREAD_COUNT), CUDA_THREAD_COUNT, 0, cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device_Mod, maxRef);
 		}
-	}
-	else if (GPUAlgo == 1) //Split shifts in multiple kernels
-	{
-		for (int cent_x = -param.param_device.maxDisplaceCenter; cent_x <= param.param_device.maxDisplaceCenter; cent_x=cent_x+param.param_device.GridSpaceCenter)
+		else
 		{
-			for (int cent_y = -param.param_device.maxDisplaceCenter; cent_y <= param.param_device.maxDisplaceCenter; cent_y=cent_y+param.param_device.GridSpaceCenter)
-			{
-				compareRefMap_kernel <<<divup(maxRef, CUDA_THREAD_COUNT), CUDA_THREAD_COUNT, 0, cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device_Mod, cent_x, cent_y, maxRef);
-			}
+			cout << "Invalid GPU Algorithm selected\n";
+			exit(1);
 		}
 	}
-	else if (GPUAlgo == 0) //All shifts in one kernel
-	{
-		compareRefMapShifted_kernel <<<divup(maxRef, CUDA_THREAD_COUNT), CUDA_THREAD_COUNT, 0, cudaStream>>> (iProjectionOut, iConv, pConvMap_device[iConv & 1], pProb_device, param.param_device, pRefMap_device_Mod, maxRef);
-	}
-	else
-	{
-		cout << "Invalid GPU Algorithm selected\n";
-		exit(1);
-	}
 	if (GPUWorkload < 100)
 	{
 		bioem::compareRefMaps(iProjectionOut, iConv, conv_map, localmultFFT, sumC, sumsquareC, maxRef);
@@ -159,6 +205,8 @@ int bioem_cuda::deviceInit()
 {
 	deviceExit();
 
+	if (FFTAlgo) GPUAlgo = 2;
+
 	checkCudaErrors(cudaStreamCreate(&cudaStream));
 	checkCudaErrors(cudaMalloc(&pRefMap_device, sizeof(bioem_RefMap)));
 	checkCudaErrors(cudaMalloc(&pProb_device, sizeof(bioem_Probability) * RefMap.ntotRefMap));
@@ -169,6 +217,32 @@ int bioem_cuda::deviceInit()
 	}
 	pRefMap_device_Mod = (bioem_RefMap_Mod*) pRefMap_device;
 
+	if (FFTAlgo)
+	{
+		checkCudaErrors(cudaMalloc(&pRefMapsFFT, RefMap.ntotRefMap * param.RefMapSize * sizeof(mycomplex_t)));
+		checkCudaErrors(cudaMalloc(&pFFTtmp2, RefMap.ntotRefMap * param.RefMapSize * sizeof(mycomplex_t)));
+		checkCudaErrors(cudaMalloc(&pFFTtmp, RefMap.ntotRefMap * param.param_device.NumberPixels * param.param_device.NumberPixels * sizeof(myfloat_t)));
+		checkCudaErrors(cudaMalloc(&pConvMapFFT, param.RefMapSize * sizeof(mycomplex_t) * 2));
+		cudaMemcpy(pRefMapsFFT, RefMap.RefMapsFFT, RefMap.ntotRefMap * param.RefMapSize * sizeof(mycomplex_t), cudaMemcpyHostToDevice);
+
+		int n[2] = {param.param_device.NumberPixels, param.param_device.NumberPixels};
+        if (cufftPlanMany(&plan, 2, n, NULL, 1, 0, NULL, 1, 0, CUFFT_C2R, RefMap.ntotRefMap) != CUFFT_SUCCESS)
+        {
+			cout << "Error planning CUFFT\n";
+			exit(1);
+		}
+		if (cufftSetCompatibilityMode(plan, CUFFT_COMPATIBILITY_NATIVE) != CUFFT_SUCCESS)
+		{
+			cout << "Error planning CUFFT compatibility\n";
+			exit(1);
+		}
+		if (cufftSetStream(plan, cudaStream) != CUFFT_SUCCESS)
+		{
+			cout << "Error setting CUFFT stream\n";
+			exit(1);
+		}
+	}
+
 	if (GPUAlgo == 0 || GPUAlgo == 1)
 	{
 		bioem_RefMap_Mod* RefMapGPU = new bioem_RefMap_Mod(RefMap);
@@ -179,6 +253,7 @@ int bioem_cuda::deviceInit()
 	{
 		cudaMemcpy(pRefMap_device, &RefMap, sizeof(bioem_RefMap), cudaMemcpyHostToDevice);
 	}
+
 	deviceInitialized = 1;
 	return(0);
 }
@@ -195,6 +270,14 @@ int bioem_cuda::deviceExit()
 		cudaEventDestroy(cudaEvent[i]);
 		cudaFree(pConvMap_device);
 	}
+	if (FFTAlgo)
+	{
+		cudaFree(pRefMapsFFT);
+		cudaFree(pConvMapFFT);
+		cudaFree(pFFTtmp);
+		cudaFree(pFFTtmp2);
+		cufftDestroy(plan);
+	}
 	cudaThreadExit();
 
 	deviceInitialized = 0;
diff --git a/include/bioem.h b/include/bioem.h
index 510b7a4..1892947 100644
--- a/include/bioem.h
+++ b/include/bioem.h
@@ -27,8 +27,7 @@ public:
 
 	int createProjection(int iMap, mycomplex_t* map);
 	int calcross_cor(bioem_map& localmap,myfloat_t& sum,myfloat_t& sumsquare);
-	int calProb(int iRefMap,int iOrient, int iConv,myfloat_t sumC,myfloat_t sumsquareC, float value, int disx, int disy);
-	int calculateCCFFT(int iMap, int iOrient, int iConv, myfloat_t sumC, myfloat_t sumsquareC, mycomplex_t* localConvFFT,mycomplex_t* localCCT,myfloat_t* lCC);
+	void calculateCCFFT(int iMap, int iOrient, int iConv, myfloat_t sumC, myfloat_t sumsquareC, mycomplex_t* localConvFFT,mycomplex_t* localCCT,myfloat_t* lCC);
 
 	bioem_Probability* pProb;
 
diff --git a/include/bioem_cuda_internal.h b/include/bioem_cuda_internal.h
index c553288..683d753 100644
--- a/include/bioem_cuda_internal.h
+++ b/include/bioem_cuda_internal.h
@@ -2,6 +2,7 @@
 #define BIOEM_CUDA_INTERNAL_H
 
 #include <cuda.h>
+#include <cufft.h>
 
 //Hack to make nvcc compiler accept fftw.h, float128 is not used anyway
 #define __float128 double
@@ -33,6 +34,12 @@ protected:
 	bioem_Probability* pProb_device;
 	bioem_map* pConvMap_device[2];
 
+	mycomplex_t* pRefMapsFFT;
+	mycomplex_t* pConvMapFFT;
+	mycuComplex_t* pFFTtmp2;
+	myfloat_t* pFFTtmp;
+	cufftHandle plan;
+
 	int GPUAlgo;		//GPU Algorithm to use, 0: parallelize over maps, 1: as 0 but work split in multiple kernels (better), 2: also parallelize over shifts (best)
 	int GPUAsync;		//Run GPU Asynchronously, do the convolutions on the host in parallel.
 	int GPUWorkload;	//Percentage of workload to perform on GPU. Default 100. Rest is done on processor in parallel.
diff --git a/include/defs.h b/include/defs.h
index 8df699c..ebf2e14 100644
--- a/include/defs.h
+++ b/include/defs.h
@@ -16,6 +16,8 @@ typedef float myfloat_t;
 #define myfftw_plan_dft_r2c_2d fftwf_plan_dft_r2c_2d
 #define myfftw_plan_dft_c2r_2d fftwf_plan_dft_c2r_2d
 #define myfftw_plan fftwf_plan
+#define mycufftExecC2R cufftExecC2R
+#define mycuComplex_t cuComplex
 #else
 typedef double myfloat_t;
 #define myfftw_malloc fftw_malloc
@@ -29,6 +31,8 @@ typedef double myfloat_t;
 #define myfftw_plan_dft_r2c_2d fftw_plan_dft_r2c_2d
 #define myfftw_plan_dft_c2r_2d fftw_plan_dft_c2r_2d
 #define myfftw_plan fftw_plan
+#define mycufftExecC2R cufftExecZ2D
+#define mycuComplex_t cuDoubleComplex
 #endif
 typedef myfloat_t mycomplex_t[2];
 
@@ -36,10 +40,8 @@ typedef myfloat_t mycomplex_t[2];
 #define BIOEM_MAP_SIZE_X 224
 #define BIOEM_MAP_SIZE_Y 224
 #define BIOEM_MODEL_SIZE 120000
-#define BIOEM_MAX_MAPS 12000
-#define MAX_REF_CTF 200
+#define BIOEM_MAX_MAPS 4000
 #define MAX_ORIENT 20000
-#define MAX_DISPLACE 224
 
 struct myfloat3_t
 {
@@ -53,6 +55,7 @@ struct myfloat3_t
 #define myBlockDimY blockDim.y
 #define myBlockIdxX blockIdx.x
 #define myBlockIdxY blockIdx.y
+#define myGridDimX gridDim.x
 #else
 #define __device__
 #define __host__
diff --git a/include/map.h b/include/map.h
index a6360e4..f801d16 100644
--- a/include/map.h
+++ b/include/map.h
@@ -13,27 +13,17 @@ public:
 	myfloat_t points[BIOEM_MAP_SIZE_X][BIOEM_MAP_SIZE_Y];
 };
 
-class bioem_map_forFFT
-{
-public:
-	mycomplex_t cpoints[BIOEM_MAP_SIZE_X*BIOEM_MAP_SIZE_Y];
-};
-
-class bioem_convolutedMap
-{
-public:
-	bioem_map conv[MAX_REF_CTF];
-	myfloat_t sum_convMap[MAX_REF_CTF];
-	myfloat_t sumsquare_convMap[MAX_REF_CTF];
-	myfloat_t ForLogProbfromConv[MAX_REF_CTF];
-};
-
 class bioem_RefMap
 {
 public:
+	bioem_RefMap()
+		: RefMapsFFT(NULL)	// no FFT buffer exists until PreCalculateMapsFFT allocates one
+	{
+	}
 	int readRefMaps(bioem_param& param);
 	int PreCalculateMapsFFT(bioem_param& param);
-	bioem_map_forFFT* RefMapFFT;
+
+	mycomplex_t* RefMapsFFT;
 
 	const char* filemap;
 	int ntotRefMap;
@@ -83,14 +73,6 @@ public:
 	}
 };
 
-class bioem_crossCor
-{
-public:
-	int disx[MAX_DISPLACE];
-	int disy[MAX_DISPLACE];
-	myfloat_t value[MAX_DISPLACE];
-};
-
 class bioem_Probability
 {
 public:
diff --git a/include/param.h b/include/param.h
index e211b67..5b85e47 100644
--- a/include/param.h
+++ b/include/param.h
@@ -7,8 +7,6 @@
 #include <math.h>
 #include <fftw3.h>
 
-class bioem_map_forFFT;
-
 class bioem_param_device
 {
 public:
@@ -34,7 +32,9 @@ public:
 
 	bioem_param_device param_device;
 
-	bioem_map_forFFT* refCTF;
+	int RefMapSize;
+	mycomplex_t* refCTF;
+	myfloat3_t* CtfParam;
 
 // File names
 	const char* fileinput;
@@ -66,7 +66,6 @@ public:
 	// Others
 	//myfloat_t volu;//in device class
 	myfloat3_t angles[MAX_ORIENT];
-	myfloat3_t CtfParam[MAX_REF_CTF] ;
 	int nTotGridAngles;
 	int nTotCTFs;
 	//myfloat_t Ntotpi;//in device class
diff --git a/map.cpp b/map.cpp
index 1f0a585..7e10397 100644
--- a/map.cpp
+++ b/map.cpp
@@ -27,8 +27,8 @@ int bioem_RefMap::readRefMaps(bioem_param& param)
 		fread(&ntotRefMap, sizeof(ntotRefMap), 1, fp);
 		if (ntotRefMap > BIOEM_MAX_MAPS)
 		{
-			cout << "BIOEM_MAX_MAPS too small\n";
-			exit(1);
+			cout << "BIOEM_MAX_MAPS too small, some maps dropped\n";
+			ntotRefMap = BIOEM_MAX_MAPS;
 		}
 		fread(&Ref[0], sizeof(Ref[0]), ntotRefMap, fp);
 		fclose(fp);
@@ -135,12 +135,12 @@ int bioem_RefMap::PreCalculateMapsFFT(bioem_param& param)
 
 	myfloat_t* localMap;
 
-	localMap= (myfloat_t *) myfftw_malloc(sizeof(myfloat_t) *param.param_device.NumberPixels*param.param_device.NumberPixels);
+	localMap= (myfloat_t *) myfftw_malloc(sizeof(myfloat_t) * param.param_device.NumberPixels * param.param_device.NumberPixels);
 
-	RefMapFFT = new bioem_map_forFFT[ntotRefMap];
+	RefMapsFFT = new mycomplex_t[ntotRefMap * param.RefMapSize];
 
 	mycomplex_t* localout;
-	localout= (mycomplex_t *) myfftw_malloc(sizeof(mycomplex_t) *param.param_device.NumberPixels*param.param_device.NumberFFTPixels1D);
+	localout= (mycomplex_t *) myfftw_malloc(sizeof(mycomplex_t) * param.param_device.NumberPixels * param.param_device.NumberFFTPixels1D);
 
 	for (int iRefMap = 0; iRefMap < ntotRefMap ; iRefMap++)
 	{
@@ -157,16 +157,12 @@ int bioem_RefMap::PreCalculateMapsFFT(bioem_param& param)
 		myfftw_execute_dft_r2c(param.fft_plan_r2c_forward,localMap,localout);
 
 		// Normalizing and saving the Reference CTFs
-		for(int i=0; i < param.param_device.NumberPixels ; i++ )
+		mycomplex_t* RefMap = &RefMapsFFT[iRefMap * param.RefMapSize];
+		for(int i=0; i < param.param_device.NumberPixels * param.param_device.NumberFFTPixels1D ; i++ )
 		{
-			for(int j=0; j < param.param_device.NumberFFTPixels1D ; j++ )
-			{
-				RefMapFFT[iRefMap].cpoints[i*param.param_device.NumberFFTPixels1D+j][0]=localout[i*param.param_device.NumberFFTPixels1D+j][0];
-				RefMapFFT[iRefMap].cpoints[i*param.param_device.NumberFFTPixels1D+j][1]=localout[i*param.param_device.NumberFFTPixels1D+j][1];
-			}
+			RefMap[i][0]=localout[i][0];
+			RefMap[i][1]=localout[i][1];
 		}
-
-
 	}
 
 	myfftw_free(localMap);
diff --git a/param.cpp b/param.cpp
index cc4201b..ae189ab 100644
--- a/param.cpp
+++ b/param.cpp
@@ -30,6 +30,9 @@ bioem_param::bioem_param()
 	numberGridPointsDisplaceCenter = 0;
 
 	fft_plans_created = 0;
+
+	refCTF = NULL;
+	CtfParam = NULL;
 }
 
 int bioem_param::readParameters()
@@ -177,6 +180,7 @@ int bioem_param::readParameters()
 	}
 	input.close();
 	param_device.NumberFFTPixels1D = param_device.NumberPixels / 2 + 1;
+	RefMapSize = param_device.NumberPixels * param_device.NumberFFTPixels1D;
 	cout << " +++++++++++++++++++++++++++++++++++++++++ \n";
 
 	cout << "Preparing FFTs\n";
@@ -271,7 +275,12 @@ int bioem_param::CalculateRefCTF()
 	int n=0;
 
 	localCTF= (myfloat_t *) myfftw_malloc(sizeof(myfloat_t) * param_device.NumberPixels*param_device.NumberPixels);
-	refCTF = new bioem_map_forFFT[MAX_REF_CTF];
+
+	nTotCTFs = numberGridPointsCTF_amp * numberGridPointsCTF_phase * numberGridPointsEnvelop;
+	delete[] refCTF;
+	refCTF = new mycomplex_t[nTotCTFs * RefMapSize];
+	delete[] CtfParam;
+	CtfParam = new myfloat3_t[nTotCTFs];
 
 	for (int iamp = 0; iamp <  numberGridPointsCTF_amp ; iamp++) //Loop over amplitud
 	{
@@ -308,29 +317,27 @@ int bioem_param::CalculateRefCTF()
 				myfftw_execute_dft_r2c(fft_plan_r2c_forward,localCTF,localout);
 
 				// Normalizing and saving the Reference CTFs
-				for(int i=0; i < param_device.NumberPixels ; i++ )
+				mycomplex_t* curRef = &refCTF[n * RefMapSize];
+				for(int i=0; i < param_device.NumberPixels * param_device.NumberFFTPixels1D; i++ )
 				{
-					for(int j=0; j < param_device.NumberFFTPixels1D ; j++ )
-					{
-						refCTF[n].cpoints[i*param_device.NumberFFTPixels1D+j][0]=localout[i*param_device.NumberFFTPixels1D+j][0];
-						refCTF[n].cpoints[i*param_device.NumberFFTPixels1D+j][1]=localout[i*param_device.NumberFFTPixels1D+j][1];
-					}
+					curRef[i][0] = localout[i][0];
+					curRef[i][1] = localout[i][1];
 				}
 				CtfParam[n].pos[0]=amp;
 				CtfParam[n].pos[1]=phase;
 				CtfParam[n].pos[2]=env;
 				n++;
 				myfftw_free(localout);
-				if(n>MAX_REF_CTF)
-				{   cout << n << "PROBLEM WITH CTF KERNEL PARAMETERS AND MAX NUMBER ALLOWED\n";
-					exit(1);
-				}
 			}
 		}
 	}
 
 	myfftw_free(localCTF);
-	nTotCTFs=n;
+	if (nTotCTFs != n)
+	{
+		cout << "Internal error during CTF preparation\n";
+		exit(1);
+	}
 
 	return(0);
 }
@@ -347,4 +354,6 @@ bioem_param::~bioem_param()
 	numberGridPointsCTF_phase = 0;
 	param_device.maxDisplaceCenter = 0;
 	numberGridPointsDisplaceCenter = 0;
+	delete[] refCTF;
+	delete[] CtfParam;
 }
-- 
GitLab