From 494d2e61f89d63e393741677b8ad79a6d979984a Mon Sep 17 00:00:00 2001
From: Cristian C Lalescu <Cristian.Lalescu@ds.mpg.de>
Date: Mon, 9 May 2016 16:27:51 +0200
Subject: [PATCH] add io of vector<int> and vector<double>

---
 bfps/_base.py         | 35 ++++++++++++++++++++++--
 bfps/cpp/base.hpp     |  1 +
 bfps/cpp/io_tools.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++
 bfps/cpp/io_tools.hpp | 41 ++++++++++++++++++++++++++++
 setup.py              |  1 +
 tests/test_io.py      |  2 ++
 6 files changed, 140 insertions(+), 2 deletions(-)
 create mode 100644 bfps/cpp/io_tools.cpp
 create mode 100644 bfps/cpp/io_tools.hpp

diff --git a/bfps/_base.py b/bfps/_base.py
index 395c423e..e0c686f0 100644
--- a/bfps/_base.py
+++ b/bfps/_base.py
@@ -60,6 +60,13 @@ class _base(object):
                 src_txt += 'int ' + key[i] + ';\n'
             elif type(parameters[key[i]]) == str:
                 src_txt += 'char ' + key[i] + '[{0}];\n'.format(self.string_length)
+            elif type(parameters[key[i]]) == np.ndarray:
+                src_txt += 'std::vector<'
+                if parameters[key[i]].dtype == np.float64:
+                    src_txt += 'double'
+                elif parameters[key[i]].dtype == np.int:
+                    src_txt += 'int'
+                src_txt += '> ' + key[i] + ';\n'
             else:
                 src_txt += 'double ' + key[i] + ';\n'
         return src_txt
@@ -80,7 +87,8 @@ class _base(object):
                    'sprintf(fname, "%s.h5", simname);\n' +
                    'parameter_file = H5Fopen(fname, H5F_ACC_RDONLY, H5P_DEFAULT);\n')
         for i in range(len(key)):
-            src_txt += 'dset = H5Dopen(parameter_file, "/{0}/{1}", H5P_DEFAULT);\n'.format(file_group, key[i])
+            src_txt += 'dset = H5Dopen(parameter_file, "/{0}/{1}", H5P_DEFAULT);\n'.format(
+                    file_group, key[i])
             if type(parameters[key[i]]) == int:
                 src_txt += 'H5Dread(dset, H5T_NATIVE_INT, H5S_ALL, H5S_ALL, H5P_DEFAULT, &{0});\n'.format(key[i])
             elif type(parameters[key[i]]) == str:
@@ -93,6 +101,13 @@ class _base(object):
                             'free(string_data);\n' +
                             'H5Sclose(space);\n' +
                             'H5Tclose(memtype);\n')
+            elif type(parameters[key[i]]) == np.ndarray:
+                if parameters[key[i]].dtype in [np.int, np.int64, np.int32]:
+                    template_par = 'int'
+                elif parameters[key[i]].dtype == np.float64:
+                    template_par = 'double'
+                src_txt += '{0} = read_vector<{1}>(parameter_file, "/{2}/{0}");\n'.format(
+                        key[i], template_par, file_group)
             else:
                 src_txt += 'H5Dread(dset, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, &{0});\n'.format(key[i])
             src_txt += 'H5Dclose(dset);\n'
@@ -107,6 +122,19 @@ class _base(object):
                 src_txt += 'DEBUG_MSG("'+ key[i] + ' = %d\\n", ' + key[i] + ');\n'
             elif type(self.parameters[key[i]]) == str:
                 src_txt += 'DEBUG_MSG("'+ key[i] + ' = %s\\n", ' + key[i] + ');\n'
+            elif type(self.parameters[key[i]]) == np.ndarray:
+                src_txt += ('for (int array_counter=0; array_counter<' +
+                            key[i] +
+                            '.size(); array_counter++)\n' +
+                            '{\n' +
+                            'DEBUG_MSG("' + key[i] + '[%d] = %')
+                if self.parameters[key[i]].dtype == np.int:
+                    src_txt += 'd'
+                elif self.parameters[key[i]].dtype == np.float64:
+                    src_txt += 'g'
+                src_txt += ('\\n", array_counter, ' +
+                            key[i] +
+                            '[array_counter]);\n}\n')
             else:
                 src_txt += 'DEBUG_MSG("'+ key[i] + ' = %g\\n", ' + key[i] + ');\n'
         return src_txt
@@ -152,7 +180,10 @@ class _base(object):
         with h5py.File(os.path.join(self.work_dir, self.simname + '.h5'), 'r') as data_file:
             for k in data_file['parameters'].keys():
                 if k in self.parameters.keys():
-                    self.parameters[k] = type(self.parameters[k])(data_file['parameters/' + k].value)
+                    if type(self.parameters[k]) in [int, str, float]:
+                        self.parameters[k] = type(self.parameters[k])(data_file['parameters/' + k].value)
+                    else:
+                        self.parameters[k] = data_file['parameters/' + k].value
         return None
     def pars_from_namespace(
             self,
diff --git a/bfps/cpp/base.hpp b/bfps/cpp/base.hpp
index f3df9c71..ee2d74d5 100644
--- a/bfps/cpp/base.hpp
+++ b/bfps/cpp/base.hpp
@@ -28,6 +28,7 @@
 #include <stdarg.h>
 #include <iostream>
 #include <typeinfo>
+#include "io_tools.hpp"
 
 #ifndef BASE
 
diff --git a/bfps/cpp/io_tools.cpp b/bfps/cpp/io_tools.cpp
new file mode 100644
index 00000000..224803dc
--- /dev/null
+++ b/bfps/cpp/io_tools.cpp
@@ -0,0 +1,62 @@
+/**********************************************************************
+*                                                                     *
+*  Copyright 2015 Max Planck Institute                                *
+*                 for Dynamics and Self-Organization                  *
+*                                                                     *
+*  This file is part of bfps.                                         *
+*                                                                     *
+*  bfps is free software: you can redistribute it and/or modify       *
+*  it under the terms of the GNU General Public License as published  *
+*  by the Free Software Foundation, either version 3 of the License,  *
+*  or (at your option) any later version.                             *
+*                                                                     *
+*  bfps is distributed in the hope that it will be useful,            *
+*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *
+*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
+*  GNU General Public License for more details.                       *
+*                                                                     *
+*  You should have received a copy of the GNU General Public License  *
+*  along with bfps.  If not, see <http://www.gnu.org/licenses/>       *
+*                                                                     *
+* Contact: Cristian.Lalescu@ds.mpg.de                                 *
+*                                                                     *
+**********************************************************************/
+
+
+
+#include <typeinfo>
+#include <cassert>
+#include "io_tools.hpp"
+
+
+template <typename number>
+std::vector<number> read_vector(
+        hid_t group,
+        std::string dset_name)
+{
+    std::vector<number> result;
+    hsize_t vector_length;
+    // first, read size of array
+    hid_t dset, dspace;
+    hid_t mem_dtype;
+    if (typeid(number) == typeid(int))
+        mem_dtype = H5Tcopy(H5T_NATIVE_INT);
+    else if (typeid(number) == typeid(double))
+        mem_dtype = H5Tcopy(H5T_NATIVE_DOUBLE);
+    dset = H5Dopen(group, dset_name.c_str(), H5P_DEFAULT);
+    dspace = H5Dget_space(dset);
+    assert(H5Sget_simple_extent_ndims(dspace) == 1);
+    H5Sget_simple_extent_dims(dspace, &vector_length, NULL);
+    result.resize(vector_length);
+    H5Dread(dset, mem_dtype, H5S_ALL, H5S_ALL, H5P_DEFAULT, &result.front());
+    H5Sclose(dspace);
+    H5Dclose(dset);
+    H5Tclose(mem_dtype);
+    return result;
+}
+
+template std::vector<int> read_vector(
+        hid_t, std::string);
+template std::vector<double> read_vector(
+        hid_t, std::string);
+
diff --git a/bfps/cpp/io_tools.hpp b/bfps/cpp/io_tools.hpp
new file mode 100644
index 00000000..69c0e8bb
--- /dev/null
+++ b/bfps/cpp/io_tools.hpp
@@ -0,0 +1,41 @@
+/**********************************************************************
+*                                                                     *
+*  Copyright 2015 Max Planck Institute                                *
+*                 for Dynamics and Self-Organization                  *
+*                                                                     *
+*  This file is part of bfps.                                         *
+*                                                                     *
+*  bfps is free software: you can redistribute it and/or modify       *
+*  it under the terms of the GNU General Public License as published  *
+*  by the Free Software Foundation, either version 3 of the License,  *
+*  or (at your option) any later version.                             *
+*                                                                     *
+*  bfps is distributed in the hope that it will be useful,            *
+*  but WITHOUT ANY WARRANTY; without even the implied warranty of     *
+*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
+*  GNU General Public License for more details.                       *
+*                                                                     *
+*  You should have received a copy of the GNU General Public License  *
+*  along with bfps.  If not, see <http://www.gnu.org/licenses/>       *
+*                                                                     *
+* Contact: Cristian.Lalescu@ds.mpg.de                                 *
+*                                                                     *
+**********************************************************************/
+
+
+
+#include <hdf5.h>
+#include <vector>
+#include <string>
+
+#ifndef IO_TOOLS
+
+#define IO_TOOLS
+
+template <typename number>
+std::vector<number> read_vector(
+        hid_t group,
+        std::string dset_name);
+
+#endif//IO_TOOLS
+
diff --git a/setup.py b/setup.py
index b3245758..6c4c277e 100644
--- a/setup.py
+++ b/setup.py
@@ -99,6 +99,7 @@ src_file_list = ['field',
                  'interpolator_base',
                  'fluid_solver',
                  'fluid_solver_base',
+                 'io_tools',
                  'fftw_tools',
                  'spline_n1',
                  'spline_n2',
diff --git a/tests/test_io.py b/tests/test_io.py
index 327889ec..ce825c80 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -41,6 +41,8 @@ class test_io(_code):
         self.parameters['other_string_parameter'] = 'another test string'
         self.parameters['niter_todo'] = 0
         self.parameters['real_number'] = 1.21
+        self.parameters['real_array'] = np.array([1.3, 1.5, 0.4])
+        self.parameters['int_array'] = np.array([1, 3, 5, 4])
         self.main_start += self.cprint_pars()
         return None
 
-- 
GitLab