Skip to content
Snippets Groups Projects
Commit 1b92ef2d authored by Cristian Lalescu's avatar Cristian Lalescu
Browse files

add vector_field class

parent 6eecc226
No related branches found
No related tags found
No related merge requests found
......@@ -61,6 +61,7 @@ base_files := \
fftw_tools \
Morton_shuffler \
p3DFFT_to_iR \
vector_field \
fluid_solver
#headers := $(patsubst %, ./src/%.hpp, ${base_files})
......
......@@ -2,15 +2,22 @@ fluid_solver<float> *fs;
fs = new fluid_solver<float>(32, 32, 32);
DEBUG_MSG("fluid_solver object created\n");
fs->fc->read(
vector_field<float> cv(fs->cd, fs->cvorticity);
vector_field<float> rv(fs->cd, fs->rvorticity);
fs->cd->read(
"Kdata0",
(void*)fs->cvorticity);
fftwf_execute(*(fftwf_plan*)fs->c2r_vorticity);
//rv*(1. / (fs->rd->sizes[0]*fs->rd->sizes[1]*fs->rd->sizes[2]));
fftwf_execute(*(fftwf_plan*)fs->r2c_vorticity);
fs->fc->write(
cv = cv*(1. / (fs->rd->sizes[0]*fs->rd->sizes[1]*fs->rd->sizes[2]));
fs->cd->write(
"Kdata1",
(void*)fs->cvorticity);
DEBUG_MSG("full size is %ld\n", fs->rd->full_size);
delete fs;
DEBUG_MSG("fluid_solver object deleted\n");
......@@ -28,7 +28,7 @@
/*****************************************************************************/
/* macro for specializations to numeric types compatible with FFTW */
#define FLUID_SOLVER_DEFINITIONS(FFTW, R, C) \
#define FLUID_SOLVER_DEFINITIONS(FFTW, R, C, MPI_RNUM, MPI_CNUM) \
\
template<> \
fluid_solver<R>::fluid_solver( \
......@@ -36,20 +36,29 @@ fluid_solver<R>::fluid_solver( \
int ny, \
int nz) \
{ \
get_descriptors_3D<R>(nz, ny, nx, &this->fr, &this->fc);\
this->cvorticity = FFTW(alloc_complex)(this->fc->local_size*3);\
this->cvelocity = FFTW(alloc_complex)(this->fc->local_size*3);\
this->rvorticity = FFTW(alloc_real)(this->fc->local_size*6);\
this->rvelocity = FFTW(alloc_real)(this->fc->local_size*6);\
int ntmp[4]; \
ntmp[0] = nz; \
ntmp[1] = ny; \
ntmp[2] = nx; \
ntmp[3] = 3; \
this->rd = new field_descriptor<R>( \
4, ntmp, MPI_RNUM, MPI_COMM_WORLD);\
ntmp[2] = nx/2 + 1; \
this->cd = new field_descriptor<R>( \
4, ntmp, MPI_CNUM, MPI_COMM_WORLD);\
this->cvorticity = FFTW(alloc_complex)(this->cd->local_size);\
this->cvelocity = FFTW(alloc_complex)(this->cd->local_size);\
this->rvorticity = FFTW(alloc_real)(this->cd->local_size*2);\
this->rvelocity = FFTW(alloc_real)(this->cd->local_size*2);\
\
this->c2r_vorticity = new FFTW(plan);\
this->r2c_vorticity = new FFTW(plan);\
this->c2r_velocity = new FFTW(plan);\
this->r2c_velocity = new FFTW(plan);\
\
ptrdiff_t sizes[] = {this->fr->sizes[0], \
this->fr->sizes[1], \
this->fr->sizes[2]};\
ptrdiff_t sizes[] = {nz, \
ny, \
nx};\
\
*(FFTW(plan)*)this->c2r_vorticity = FFTW(mpi_plan_many_dft_c2r)( \
3, sizes, 3, \
......@@ -101,6 +110,9 @@ fluid_solver<R>::~fluid_solver() \
FFTW(free)(this->rvorticity);\
FFTW(free)(this->cvelocity);\
FFTW(free)(this->rvelocity);\
\
delete this->cd; \
delete this->rd; \
} \
\
template<> \
......@@ -112,10 +124,16 @@ void fluid_solver<R>::step() \
/*****************************************************************************/
/* now actually use the macro defined above */
FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_FLOAT, float, fftwf_complex)
FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_DOUBLE, double, fftw_complex)
//FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_LONG_DOUBLE, long double, fftwl_complex)
//FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_QUAD, __float128, fftwq_complex)
FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_FLOAT,
float,
fftwf_complex,
MPI_REAL4,
MPI_COMPLEX8)
FLUID_SOLVER_DEFINITIONS(FFTW_MANGLE_DOUBLE,
double,
fftw_complex,
MPI_REAL8,
MPI_COMPLEX16)
/*****************************************************************************/
......
......@@ -22,6 +22,7 @@
#include <stdlib.h>
#include <iostream>
#include "field_descriptor.hpp"
#include "vector_field.hpp"
#ifndef FLUID_SOLVER
......@@ -41,7 +42,7 @@ class fluid_solver
private:
typedef rnumber cnumber[2];
public:
field_descriptor<rnumber> *fc, *fr;
field_descriptor<rnumber> *cd, *rd;
/* fields */
rnumber *rvorticity;
......
#include "vector_field.hpp"
/* destructor doesn't actually do anything */
template <class rnumber>
vector_field<rnumber>::~vector_field()
{}
template <class rnumber>
vector_field<rnumber>::vector_field(
field_descriptor<rnumber> *d,
rnumber *data)
{
this->is_real = true;
this->cdata = (rnumber (*)[2])(data);
this->rdata = data;
this->descriptor = d;
}
template <class rnumber>
vector_field<rnumber>::vector_field(
field_descriptor<rnumber> *d,
rnumber (*data)[2])
{
this->is_real = false;
this->rdata = (rnumber*)(&data[0][0]);
this->cdata = data;
this->descriptor = d;
}
template <class rnumber>
vector_field<rnumber>& vector_field<rnumber>::operator*(rnumber factor)
{
ptrdiff_t i;
for (i = 0;
i < this->descriptor->local_size * 2;
i++)
*(this->rdata + i) *= factor;
return *this;
}
template class vector_field<float>;
#include "field_descriptor.hpp"
template <class rnumber>
class vector_field
{
private:
field_descriptor<rnumber> *descriptor;
rnumber *rdata;
rnumber (*cdata)[2];
bool is_real;
public:
vector_field(field_descriptor<rnumber> *d, rnumber *data);
vector_field(field_descriptor<rnumber> *d, rnumber (*data)[2]);
~vector_field();
/* various operators */
vector_field &operator*(rnumber factor);
};
test.py 100644 → 100755
#! /usr/bin/env python2
import numpy as np
import subprocess
import pyfftw
import matplotlib.pyplot as plt
def run_test(
test_name = 'test_FFT',
......@@ -46,13 +49,29 @@ def generate_data_3D(
a[ii] = 0
return a
Kdata0 = generate_data_3D(32, p = 2).astype(np.complex64)
n = 32
Kdata00 = generate_data_3D(n, p = 2).astype(np.complex64)
Kdata01 = generate_data_3D(n, p = 2).astype(np.complex64)
Kdata02 = generate_data_3D(n, p = 2).astype(np.complex64)
Kdata0 = np.zeros(
Kdata00.shape + (3,),
Kdata00.dtype)
Kdata0[..., 0] = Kdata00
Kdata0[..., 1] = Kdata01
Kdata0[..., 2] = Kdata02
Kdata0.tofile("Kdata0")
run_test('test_FFT')
Kdata1 = np.fromfile('Kdata1', dtype = np.complex64).reshape(Kdata0.shape)
print np.max(np.abs(Kdata0 - Kdata1))
print np.max(np.abs(Kdata0))
fig = plt.figure(figsize=(12, 6))
a = fig.add_subplot(121)
a.imshow(abs(Kdata0[4, :, :, 2]), interpolation = 'none')
a = fig.add_subplot(122)
a.imshow(abs(Kdata1[4, :, :, 2]), interpolation = 'none')
fig.savefig('tmp.pdf', format = 'pdf')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment