Skip to content
Snippets Groups Projects
base.py 6.10 KiB
#######################################################################
#                                                                     #
#  Copyright 2015 Max Planck Institute                                #
#                 for Dynamics and Self-Organization                  #
#                                                                     #
#  This file is part of bfps.                                         #
#                                                                     #
#  bfps is free software: you can redistribute it and/or modify       #
#  it under the terms of the GNU General Public License as published  #
#  by the Free Software Foundation, either version 3 of the License,  #
#  or (at your option) any later version.                             #
#                                                                     #
#  bfps is distributed in the hope that it will be useful,            #
#  but WITHOUT ANY WARRANTY; without even the implied warranty of     #
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      #
#  GNU General Public License for more details.                       #
#                                                                     #
#  You should have received a copy of the GNU General Public License  #
#  along with bfps.  If not, see <http://www.gnu.org/licenses/>       #
#                                                                     #
# Contact: Cristian.Lalescu@ds.mpg.de                                 #
#                                                                     #
#######################################################################



import os
import numpy as np
import h5py
import bfps

class base(object):
    """
        This class contains simulation parameters, and handles parameter related
        functionalities of both python objects and C++ codes.

        Parameters live in ``self.parameters`` (name -> value). The value's
        exact Python type selects its C representation: ``int`` -> ``int``,
        ``str`` -> ``char[self.string_length]``, anything else -> ``double``.
    """
    def __init__(
            self,
            work_dir = './',
            simname = 'test'):
        self.iorank = 0
        ### simulation parameters
        self.parameters = {'nx' : 32,
                           'ny' : 32,
                           'nz' : 32}
        # size of the C char arrays emitted for string parameters
        self.string_length = 512
        self.work_dir = work_dir
        self.simname = simname
        return None
    @staticmethod
    def _par_ctype(value):
        """Classify a parameter value as 'int', 'str' or 'double' for C code.

        The checks are on the exact type (not ``isinstance``), so e.g. bools
        and numpy scalars fall through to 'double'.
        """
        if type(value) is int:
            return 'int'
        if type(value) is str:
            return 'str'
        return 'double'
    def cdef_pars(self):
        """Return C declarations for all parameters, in sorted key order."""
        lines = []
        for key in sorted(self.parameters):
            ctype = self._par_ctype(self.parameters[key])
            if ctype == 'int':
                lines.append('int {0};\n'.format(key))
            elif ctype == 'str':
                lines.append('char {0}[{1}];\n'.format(key, self.string_length))
            else:
                lines.append('double {0};\n'.format(key))
        return ''.join(lines)
    def cread_pars(self):
        """Return C source for ``int read_parameters(hid_t data_file_id)``.

        For every parameter (sorted key order) the generated function opens
        the HDF5 dataset ``parameters/<key>`` and reads it into the variable
        of the same name (``cdef_pars`` emits declarations with matching
        names and types): ``H5T_NATIVE_INT`` for int parameters,
        ``H5T_NATIVE_DOUBLE`` for double ones, and a variable-length C string
        for str parameters. Strings longer than ``string_length - 1`` are
        truncated rather than overflowing the destination array.
        """
        parts = ['int read_parameters(hid_t data_file_id)\n{\n'
                 'hid_t dset, memtype;\n'
                 'char *string_data;\n']
        for key in sorted(self.parameters):
            ctype = self._par_ctype(self.parameters[key])
            parts.append('dset = H5Dopen(data_file_id, "parameters/{0}", H5P_DEFAULT);\n'.format(key))
            if ctype == 'int':
                parts.append('H5Dread(dset, H5T_NATIVE_INT, H5S_ALL, H5S_ALL, H5P_DEFAULT, &{0});\n'.format(key))
            elif ctype == 'str':
                # With a variable-length memory type, H5Dread allocates the
                # string itself and stores its address in string_data, so no
                # buffer is pre-allocated (it would only be leaked).
                parts.append('memtype = H5Tcopy(H5T_C_S1);\n'
                             'H5Tset_size(memtype, H5T_VARIABLE);\n'
                             'string_data = NULL;\n'
                             'H5Dread(dset, memtype, H5S_ALL, H5S_ALL, H5P_DEFAULT, &string_data);\n'
                             'if (string_data != NULL)\n{\n')
                # Bounded copy: the destination is a char[string_length] array.
                parts.append('snprintf({0}, {1}, "%s", string_data);\n'.format(key, self.string_length))
                # NOTE(review): free() assumes HDF5's default malloc-based
                # vlen allocator, as the previous code also did.
                parts.append('free(string_data);\n}\n'
                             'H5Tclose(memtype);\n')
            else:
                parts.append('H5Dread(dset, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, &{0});\n'.format(key))
            parts.append('H5Dclose(dset);\n')
        parts.append('return 0;\n}\n') # finishing read_parameters
        return ''.join(parts)
    def cprint_pars(self):
        """Return C ``DEBUG_MSG`` statements printing every parameter (sorted)."""
        printf_spec = {'int': '%d', 'str': '%s', 'double': '%g'}
        return ''.join(
            'DEBUG_MSG("{0} = {1}\\n", {0});\n'.format(
                key, printf_spec[self._par_ctype(self.parameters[key])])
            for key in sorted(self.parameters))
    def write_par(self, iter0 = 0):
        """Create ``<work_dir>/<simname>.h5`` holding the parameters.

        Writes ``parameters/<key>`` for each parameter, ``iteration`` =
        ``int(iter0)``, and ``install_info/<key>`` (stringified) for each
        entry of ``bfps.install_info``. ``work_dir`` is created if missing.
        The h5py mode ``'w-'`` makes this fail if the file already exists.
        """
        if not os.path.isdir(self.work_dir):
            os.makedirs(self.work_dir)
        fname = os.path.join(self.work_dir, self.simname + '.h5')
        # context manager: the file is closed even if a write raises
        with h5py.File(fname, 'w-') as ofile:
            for k in self.parameters:
                ofile['parameters/' + k] = self.parameters[k]
            ofile['iteration'] = int(iter0)
            for k in bfps.install_info.keys():
                ofile['install_info/' + k] = str(bfps.install_info[k])
        return None
    def read_parameters(self):
        """Update known parameters from ``<work_dir>/<simname>.h5``.

        Only keys already present in ``self.parameters`` are read; each value
        is converted to the type of the current value. Datasets in the file
        with unknown names are ignored.
        """
        fname = os.path.join(self.work_dir, self.simname + '.h5')
        with h5py.File(fname, 'r') as data_file:
            for k in data_file['parameters'].keys():
                if k not in self.parameters:
                    continue
                target_type = type(self.parameters[k])
                # dset[()] reads the whole (scalar) dataset; `.value` was
                # removed in h5py 3.0.
                value = data_file['parameters/' + k][()]
                if (target_type is str and isinstance(value, bytes)
                        and not isinstance(value, str)):
                    # h5py >= 3.0 returns variable-length strings as bytes;
                    # str(b'...') would yield "b'...'" instead of the text.
                    value = value.decode('utf-8')
                self.parameters[k] = target_type(value)
        return None
    def pars_from_namespace(self, opt):
        """Take simname, work_dir and every known parameter from ``opt``.

        ``opt`` is any object whose ``vars()`` holds all keys of
        ``self.parameters`` (e.g. an ``argparse.Namespace``); a missing key
        raises ``KeyError``. Extra attributes are ignored.
        """
        new_pars = vars(opt)
        self.simname = opt.simname
        self.work_dir = opt.work_dir
        for k in self.parameters:
            self.parameters[k] = new_pars[k]
        return None
    def get_coord(self, direction):
        """Return the grid coordinates 2*pi*i/n, i = 0..n-1, along ``direction``.

        ``direction`` is one of 'x', 'y', 'z'; n is ``parameters['n' + direction]``.
        """
        assert(direction in ('x', 'y', 'z'))
        n = self.parameters['n' + direction]
        return np.arange(.0, n)*2*np.pi / n