Skip to content
Snippets Groups Projects
Commit 1d1d6e85 authored by Chichi Lalescu's avatar Chichi Lalescu
Browse files

use hdf5 for parameter file

parent 741ad419
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@
import os
import h5py
class base(object):
def __init__(
......@@ -52,25 +53,25 @@ class base(object):
key = self.parameters.keys()
key.sort()
src_txt = ('int read_parameters()\n{\n'
+ 'int err_while_reading = 0, errr;\n'
+ 'if (myrank == {0})'.format(self.iorank)
+ '\n{\n'
+ 'FILE *par_file;\n'
+ 'char fname[{0}];\n'.format(self.string_length)
+ 'sprintf(fname, "%s_pars.txt", simname);\n'
+ 'par_file = fopen(fname, "r");\n')
+ 'sprintf(fname, "%s.h5", simname);\n'
+ 'H5::H5File par_file(fname, H5F_ACC_RDONLY);\n'
+ 'H5::DataSet dset;\n'
+ 'H5::StrType strdtype(0, H5T_VARIABLE);\n'
+ 'H5::DataSpace strdspace(H5S_SCALAR);\n'
+ 'std::string tempstr;')
#src_txt += 'std::cerr << fname << std::endl;\n'
for i in range(len(key)):
src_txt += 'dset = par_file.openDataSet("parameters/{0}");\n'.format(key[i])
if type(self.parameters[key[i]]) == int:
src_txt += ('if (fscanf(par_file, "' + key[i] + ' = %d\\n", &' + key[i] + ') != 1)\n'
+ ' err_while_reading++;\n')
src_txt += 'dset.read(&{0}, H5::PredType::NATIVE_INT);\n'.format(key[i])
elif type(self.parameters[key[i]]) == str:
src_txt += ('if (fscanf(par_file, "' + key[i] + ' = %s\\n", ' + key[i] + ') != 1)\n'
+ ' err_while_reading++;\n')
src_txt += ('dset.read(tempstr, strdtype, strdspace);\n' +
'sprintf({0}, "%s", tempstr.c_str());\n').format(key[i])
else:
src_txt += ('if (fscanf(par_file, "' + key[i] + ' = %le\\n", &' + key[i] + ') != 1)\n'
+ ' err_while_reading++;\n')
#src_txt += 'DEBUG_MSG("read ' + key[i] + ', err_while_reading is %d\\n", err_while_reading);\n'
src_txt += 'dset.read(&{0}, H5::PredType::NATIVE_DOUBLE);\n'.format(key[i])
src_txt += '}\n' # finishing if myrank == 0
# now broadcasting values to all ranks
for i in range(len(key)):
......@@ -80,13 +81,7 @@ class base(object):
src_txt += 'MPI_Bcast((void*)(' + key[i] + '), {0}, MPI_CHAR, {1}, MPI_COMM_WORLD);\n'.format(self.string_length, self.iorank)
else:
src_txt += 'MPI_Bcast((void*)(&' + key[i] + '), 1, MPI_DOUBLE, {0}, MPI_COMM_WORLD);\n'.format(self.iorank)
src_txt += ('MPI_Allreduce((void*)(&err_while_reading), (void*)(&errr), 1, MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD);\n'
+ 'if (errr > 0)\n{\n'
+ 'fprintf(stderr, "Error reading parameters.\\nAttempting to exit.\\n");\n'
+ 'MPI_Finalize();\n'
+ 'exit(0);\n'
+ '}\n' # finishing errr check
+ 'return 0;\n}\n') # finishing read_parameters
src_txt += 'return 0;\n}\n' # finishing read_parameters
return src_txt
def cprint_pars(self):
key = self.parameters.keys()
......@@ -101,17 +96,11 @@ class base(object):
src_txt += 'DEBUG_MSG("'+ key[i] + ' = %le\\n", ' + key[i] + ');\n'
return src_txt
def write_par(self):
filename = self.simname + '_pars.txt'
if not os.path.isdir(self.work_dir):
os.makedirs(self.work_dir)
ofile = open(os.path.join(self.work_dir, filename), 'w')
key = self.parameters.keys()
key.sort()
for i in range(len(key)):
if type(self.parameters[key[i]]) == float:
ofile.write(('{0} = {1:e}\n').format(key[i], self.parameters[key[i]]))
else:
ofile.write('{0} = {1}\n'.format(key[i], self.parameters[key[i]]))
ofile = h5py.File(os.path.join(self.work_dir, self.simname + '.h5'), 'w')
for k in self.parameters.keys():
ofile['parameters/' + k] = self.parameters[k]
ofile.close()
return None
def read_parameters(self):
......
......@@ -41,6 +41,7 @@ class code(base):
#include "base.hpp"
#include "fluid_solver.hpp"
#include <iostream>
#include <H5Cpp.h>
#include <fftw3-mpi.h>
//endcpp
"""
......
......@@ -7,14 +7,14 @@ import os
hostname = os.getenv('HOSTNAME')
extra_compile_args = ['-mtune=native', '-ffast-math', '-std=c++11']
extra_libraries = []
extra_libraries = ['hdf5_cpp', 'hdf5']
if hostname == 'chichi-G':
include_dirs = ['/usr/local/include',
'/usr/include/mpich']
library_dirs = ['/usr/local/lib'
'/usr/lib/mpich']
extra_libraries = ['mpich']
extra_libraries += ['mpich']
if hostname in ['frontend01', 'frontend02']:
include_dirs = ['/usr/nld/mvapich2-1.9a2-gcc/include',
......@@ -26,7 +26,7 @@ if hostname in ['frontend01', 'frontend02']:
'/usr/nld/gcc-4.7.2/lib64',
'/usr/nld/fftw-3.3.3-mvapich2-1.9a2-gcc/lib',
'/usr/nld/fftw-3.3.3-float-mvapich2-1.9a2-gcc/lib']
extra_libraries = ['mpich']
extra_libraries += ['mpich']
if hostname == 'tolima':
local_install_dir = '/scratch.local/chichi/installs'
......
......@@ -89,7 +89,7 @@ libbfps = Extension(
setup(
name = 'bfps',
packages = ['bfps'],
install_requires = ['numpy>=1.8', 'matplotlib>=1.3'],
install_requires = ['numpy>=1.8', 'matplotlib>=1.3', 'h5py>=2.2.1'],
ext_modules = [libbfps],
package_data = {'bfps': header_list + ['../machine_settings.py',
'install_info.pickle']},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment