diff --git a/src/field_descriptor.cpp b/src/field_descriptor.cpp index d6cd7c515ef02772cb35c4d520d364f0c21db244..20a399a0b2c24bd969a0a1fe86a151ccd0a33f0d 100644 --- a/src/field_descriptor.cpp +++ b/src/field_descriptor.cpp @@ -25,7 +25,8 @@ #include "base.hpp" #include "field_descriptor.hpp" -field_descriptor::field_descriptor( +template <class rnumber> +field_descriptor<rnumber>::field_descriptor( int ndims, int *n, MPI_Datatype element_type, @@ -46,7 +47,7 @@ field_descriptor::field_descriptor( ptrdiff_t local_n0, local_0_start; for (int i = 0; i < this->ndims; i++) nfftw[i] = n[i]; - this->local_size = fftwf_mpi_local_size_many( + this->local_size = fftw_mpi_local_size_many( this->ndims, nfftw, 1, @@ -183,7 +184,8 @@ field_descriptor::field_descriptor( delete[] local_rank; } -field_descriptor::~field_descriptor() +template <class rnumber> +field_descriptor<rnumber>::~field_descriptor() { DEBUG_MSG_WAIT( MPI_COMM_WORLD, @@ -211,7 +213,8 @@ field_descriptor::~field_descriptor() delete[] this->rank; } -int field_descriptor::read( +template<> +int field_descriptor<float>::read( const char *fname, void *buffer) { @@ -250,7 +253,8 @@ int field_descriptor::read( return EXIT_SUCCESS; } -int field_descriptor::write( +template<> +int field_descriptor<float>::write( const char *fname, void *buffer) { @@ -290,7 +294,8 @@ int field_descriptor::write( return EXIT_SUCCESS; } -int field_descriptor::transpose( +template<> +int field_descriptor<float>::transpose( float *input, float *output) { @@ -327,7 +332,8 @@ int field_descriptor::transpose( return EXIT_SUCCESS; } -int field_descriptor::transpose( +template<> +int field_descriptor<float>::transpose( fftwf_complex *input, fftwf_complex *output) { @@ -381,7 +387,8 @@ int field_descriptor::transpose( return EXIT_SUCCESS; } -int field_descriptor::interleave( +template<> +int field_descriptor<float>::interleave( float *a, int dim) { @@ -411,7 +418,8 @@ int field_descriptor::interleave( return EXIT_SUCCESS; } -int field_descriptor::interleave( +template<> +int field_descriptor<float>::interleave( fftwf_complex *a, int dim) { @@ -438,7 +446,8 @@ int field_descriptor::interleave( return EXIT_SUCCESS; } -int field_descriptor::switch_endianness( +template<> +int field_descriptor<float>::switch_endianness( float *a) { for (int i = 0; i < this->local_size; i++) @@ -449,7 +458,8 @@ int field_descriptor::switch_endianness( return EXIT_SUCCESS; } -int field_descriptor::switch_endianness( +template<> +int field_descriptor<float>::switch_endianness( fftwf_complex *b) { float *a = (float*)b; @@ -463,11 +473,12 @@ int field_descriptor::switch_endianness( return EXIT_SUCCESS; } -field_descriptor* field_descriptor::get_transpose() +template<> +field_descriptor<float>* field_descriptor<float>::get_transpose() { int n[this->ndims]; for (int i=0; i<this->ndims; i++) n[i] = this->sizes[this->ndims - i - 1]; - return new field_descriptor(this->ndims, n, this->mpi_dtype, this->comm); + return new field_descriptor<float>(this->ndims, n, this->mpi_dtype, this->comm); } diff --git a/src/field_descriptor.hpp b/src/field_descriptor.hpp index 793d4e9a0e5930be7eef3a16c7198fb72ed621cf..4197cb01b0c258435b012cc27d5f019a041cad28 100644 --- a/src/field_descriptor.hpp +++ b/src/field_descriptor.hpp @@ -31,6 +31,8 @@ extern int myrank, nprocs; template <class rnumber> class field_descriptor { + private: + typedef rnumber cnumber[2]; public: /* data */ @@ -76,20 +78,20 @@ class field_descriptor rnumber *input, rnumber *output); int transpose( - rnumber *input[2], - rnumber *output[2] = NULL); + cnumber *input, + cnumber *output = NULL); int interleave( rnumber *input, int dim); int interleave( - rnumber *input[2], + cnumber *input, int dim); int switch_endianness( rnumber *a); int switch_endianness( - rnumber *a[2]); + cnumber *a); }; @@ -98,13 +100,13 @@ class field_descriptor * the arrays are assumed to use fftw layout. * */ int fftwf_copy_complex_array( - field_descriptor *fi, + field_descriptor<float> *fi, fftwf_complex *ai, - field_descriptor *fo, + field_descriptor<float> *fo, fftwf_complex *ao); int fftwf_clip_zero_padding( - field_descriptor *f, + field_descriptor<float> *f, float *a); inline float btle(const float be)