Skip to content
Snippets Groups Projects
Commit 064037e4 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Update input test to not be hard coded relative path

Fix for INSTALL_PREFIX
parent 137a4e6a
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "mpi_interface/serialize_tuple.h" #include "mpi_interface/serialize_tuple.h"
#include "utils/project.hpp" #include "utils/project.hpp"
#include "mpi_interface/MPI_Interface.hpp"
#ifdef PY_BINDINGS #ifdef PY_BINDINGS
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
......
...@@ -38,9 +38,11 @@ ...@@ -38,9 +38,11 @@
#include "feature_creation/node/FeatureNode.hpp" #include "feature_creation/node/FeatureNode.hpp"
#include "mpi_interface/MPI_Interface.hpp" #include "mpi_interface/MPI_Interface.hpp"
#if PARAMETERIZE
#ifdef PARAMETERIZE
#include "nl_opt/NLOptWrapper.hpp" #include "nl_opt/NLOptWrapper.hpp"
#endif #endif
#ifdef PY_BINDINGS #ifdef PY_BINDINGS
#include "python/py_binding_cpp_def/conversion_utils.hpp" #include "python/py_binding_cpp_def/conversion_utils.hpp"
......
...@@ -88,3 +88,13 @@ std::string str_utils::matlabify(const std::string str) ...@@ -88,3 +88,13 @@ std::string str_utils::matlabify(const std::string str)
} }
return to_ret; return to_ret;
} }
std::string str_utils::join(std::string join_section, std::string* start, int sz)
{
std::string to_ret = start[0];
for(int ii = 1; ii < sz; ++ii)
{
to_ret += join_section + start[ii];
}
return to_ret;
}
...@@ -68,6 +68,28 @@ namespace str_utils ...@@ -68,6 +68,28 @@ namespace str_utils
* @return The string with the operators replaced * @return The string with the operators replaced
*/ */
std::string op2str(const std::string str, const std::string op, const std::string op_str); std::string op2str(const std::string str, const std::string op, const std::string op_str);
/**
* @brief Join a vector of strings together
*
* @param join_section string to put in between each element
* @param start pointer to the first element of the vector to join
* @param sz Number of elements of the vector to join
* @return The joined string
*/
std::string join(std::string join_section, std::string* start, int sz);
/**
* @brief Join a vector of strings together
*
* @param join_section string to put in between each element
* @param list The vector of strings to join together
* @return The joined string
*/
inline std::string join(std::string join_section, std::vector<std::string> list)
{
return join(join_section, list.data(), list.size());
}
} }
......
...@@ -21,11 +21,11 @@ namespace ...@@ -21,11 +21,11 @@ namespace
protected: protected:
void SetUp() override void SetUp() override
{ {
std::vector<std::string> filepath = str_utils::split_string_trim(__FILE__, "/");
node_value_arrs::finialize_values_arr(); node_value_arrs::finialize_values_arr();
_sample_ids_train = {"a", "b", "c"}; _sample_ids_train = {"a", "b", "c"};
_sample_ids_test = {"d"}; _sample_ids_test = {"d"};
_task_names = {"task_1", "task_2"}; _task_names = {"task_1", "task_2"};
_allowed_param_ops = {"log"};
_allowed_ops = {"sq", "cb"}; _allowed_ops = {"sq", "cb"};
_prop_train = {1.0, 4.0, 9.0}; _prop_train = {1.0, 4.0, 9.0};
_prop_test = {16.0}; _prop_test = {16.0};
...@@ -34,8 +34,10 @@ namespace ...@@ -34,8 +34,10 @@ namespace
_task_sizes_test = {1, 0}; _task_sizes_test = {1, 0};
_phi_0 = {FeatureNode(0, "feat_1", {1.0, 2.0, 3.0}, {4.0}, Unit("m"))}; _phi_0 = {FeatureNode(0, "feat_1", {1.0, 2.0, 3.0}, {4.0}, Unit("m"))};
_prop_unit = Unit("m"); _prop_unit = Unit("m");
_filename = "googletest/inputs/sisso.json"; _data_file = (
_data_file = "googletest/inputs/data.csv"; str_utils::join("/", filepath.data(), filepath.size() - 1) +
"/data.csv"
);
_prop_key = "property"; _prop_key = "property";
_prop_label = "property"; _prop_label = "property";
_task_key = "task"; _task_key = "task";
...@@ -50,18 +52,35 @@ namespace ...@@ -50,18 +52,35 @@ namespace
_n_sis_select = 1; _n_sis_select = 1;
_n_residual = 1; _n_residual = 1;
_n_models_store = 1; _n_models_store = 1;
_fix_intercept = false;
#ifdef PARAMETERIZE
_filename = (
str_utils::join("/", filepath.data(), filepath.size() - 1) +
"/sisso_param.json"
);
_allowed_param_ops = {"log"};
_max_param_depth = 1; _max_param_depth = 1;
_nlopt_seed = 10; _nlopt_seed = 10;
_fix_intercept = false;
_global_param_opt = true; _global_param_opt = true;
_reparam_residual = true; _reparam_residual = true;
#else
_filename = (
str_utils::join("/", filepath.data(), filepath.size() - 1) +
"/sisso.json"
);
#endif
} }
std::vector<std::string> _sample_ids_train; //!< Vector storing all sample ids for the training samples std::vector<std::string> _sample_ids_train; //!< Vector storing all sample ids for the training samples
std::vector<std::string> _sample_ids_test; //!< Vector storing all sample ids for the test samples std::vector<std::string> _sample_ids_test; //!< Vector storing all sample ids for the test samples
std::vector<std::string> _task_names; //!< Vector storing the ID of the task names std::vector<std::string> _task_names; //!< Vector storing the ID of the task names
std::vector<std::string> _allowed_param_ops; //!< Vector containing all allowed operators strings for operators with free parameters
std::vector<std::string> _allowed_ops; //!< Vector containing all allowed operators strings std::vector<std::string> _allowed_ops; //!< Vector containing all allowed operators strings
std::vector<double> _prop_train; //!< The value of the property to evaluate the loss function against for the training set std::vector<double> _prop_train; //!< The value of the property to evaluate the loss function against for the training set
std::vector<double> _prop_test; //!< The value of the property to evaluate the loss function against for the test set std::vector<double> _prop_test; //!< The value of the property to evaluate the loss function against for the test set
...@@ -97,12 +116,18 @@ namespace ...@@ -97,12 +116,18 @@ namespace
int _n_samp_test; //!< Number of samples in the test set int _n_samp_test; //!< Number of samples in the test set
int _n_residual; //!< Number of residuals to pass to the next sis model int _n_residual; //!< Number of residuals to pass to the next sis model
int _n_models_store; //!< The number of models to output to files int _n_models_store; //!< The number of models to output to files
bool _fix_intercept; //!< If true the bias term is fixed at 0
#ifdef PARAMETERIZE
std::vector<std::string> _allowed_param_ops; //!< Vector containing all allowed operators strings for operators with free parameters
int _max_param_depth; //!< The maximum depth in the binary expression tree to set non-linear optimization int _max_param_depth; //!< The maximum depth in the binary expression tree to set non-linear optimization
int _nlopt_seed; //!< The seed used for the nlOpt library int _nlopt_seed; //!< The seed used for the nlOpt library
bool _fix_intercept; //!< If true the bias term is fixed at 0
bool _global_param_opt; //!< True if global optimization is requested for non-linear optimization of parameters (Can break reproducibility) bool _global_param_opt; //!< True if global optimization is requested for non-linear optimization of parameters (Can break reproducibility)
bool _reparam_residual; //!< If True then reparameterize features using the residuals of each model bool _reparam_residual; //!< If True then reparameterize features using the residuals of each model
#endif
}; };
...@@ -124,9 +149,6 @@ namespace ...@@ -124,9 +149,6 @@ namespace
inputs.set_task_names(_task_names); inputs.set_task_names(_task_names);
EXPECT_EQ(inputs.task_names()[0], _task_names[0]); EXPECT_EQ(inputs.task_names()[0], _task_names[0]);
inputs.set_allowed_param_ops(_allowed_param_ops);
EXPECT_EQ(inputs.allowed_param_ops()[0], _allowed_param_ops[0]);
inputs.set_allowed_ops(_allowed_ops); inputs.set_allowed_ops(_allowed_ops);
EXPECT_EQ(inputs.allowed_ops()[0], _allowed_ops[0]); EXPECT_EQ(inputs.allowed_ops()[0], _allowed_ops[0]);
...@@ -198,31 +220,36 @@ namespace ...@@ -198,31 +220,36 @@ namespace
inputs.set_n_models_store(_n_models_store); inputs.set_n_models_store(_n_models_store);
EXPECT_EQ(inputs.n_models_store(), _n_models_store); EXPECT_EQ(inputs.n_models_store(), _n_models_store);
inputs.set_fix_intercept(_fix_intercept);
EXPECT_EQ(inputs.fix_intercept(), _fix_intercept);
#ifdef PARAMETERIZE
inputs.set_allowed_param_ops(_allowed_param_ops);
EXPECT_EQ(inputs.allowed_param_ops()[0], _allowed_param_ops[0]);
inputs.set_max_param_depth(_max_param_depth); inputs.set_max_param_depth(_max_param_depth);
EXPECT_EQ(inputs.max_param_depth(), _max_param_depth); EXPECT_EQ(inputs.max_param_depth(), _max_param_depth);
inputs.set_nlopt_seed(_nlopt_seed); inputs.set_nlopt_seed(_nlopt_seed);
EXPECT_EQ(inputs.nlopt_seed(), _nlopt_seed); EXPECT_EQ(inputs.nlopt_seed(), _nlopt_seed);
inputs.set_fix_intercept(_fix_intercept);
EXPECT_EQ(inputs.fix_intercept(), _fix_intercept);
inputs.set_global_param_opt(_global_param_opt); inputs.set_global_param_opt(_global_param_opt);
EXPECT_EQ(inputs.global_param_opt(), _global_param_opt); EXPECT_EQ(inputs.global_param_opt(), _global_param_opt);
inputs.set_reparam_residual(_reparam_residual); inputs.set_reparam_residual(_reparam_residual);
EXPECT_EQ(inputs.reparam_residual(), _reparam_residual); EXPECT_EQ(inputs.reparam_residual(), _reparam_residual);
#endif
} }
TEST_F(InputParserTests, FileConsructor) TEST_F(InputParserTests, FileConsructor)
{ {
boost::property_tree::ptree propTree; boost::property_tree::ptree propTree;
boost::property_tree::json_parser::read_json(_filename, propTree); boost::property_tree::json_parser::read_json(_filename, propTree);
propTree.put("data_file", _data_file);
InputParser inputs(propTree, _filename, mpi_setup::comm); InputParser inputs(propTree, _filename, mpi_setup::comm);
EXPECT_EQ(inputs.sample_ids_train()[0], _sample_ids_train[0]); EXPECT_EQ(inputs.sample_ids_train()[0], _sample_ids_train[0]);
EXPECT_EQ(inputs.sample_ids_test()[0], _sample_ids_test[0]); EXPECT_EQ(inputs.sample_ids_test()[0], _sample_ids_test[0]);
EXPECT_EQ(inputs.task_names()[0], _task_names[0]); EXPECT_EQ(inputs.task_names()[0], _task_names[0]);
EXPECT_EQ(inputs.allowed_param_ops()[0], _allowed_param_ops[0]);
EXPECT_EQ(inputs.allowed_ops()[0], _allowed_ops[0]); EXPECT_EQ(inputs.allowed_ops()[0], _allowed_ops[0]);
EXPECT_EQ(inputs.prop_train()[0], _prop_train[0]); EXPECT_EQ(inputs.prop_train()[0], _prop_train[0]);
EXPECT_EQ(inputs.prop_test()[0], _prop_test[0]); EXPECT_EQ(inputs.prop_test()[0], _prop_test[0]);
...@@ -251,10 +278,14 @@ namespace ...@@ -251,10 +278,14 @@ namespace
EXPECT_EQ(inputs.n_sis_select(), _n_sis_select); EXPECT_EQ(inputs.n_sis_select(), _n_sis_select);
EXPECT_EQ(inputs.n_residual(), _n_residual); EXPECT_EQ(inputs.n_residual(), _n_residual);
EXPECT_EQ(inputs.n_models_store(), _n_models_store); EXPECT_EQ(inputs.n_models_store(), _n_models_store);
EXPECT_EQ(inputs.fix_intercept(), _fix_intercept);
#ifdef PARAMETERIZE
EXPECT_EQ(inputs.allowed_param_ops()[0], _allowed_param_ops[0]);
EXPECT_EQ(inputs.max_param_depth(), _max_param_depth); EXPECT_EQ(inputs.max_param_depth(), _max_param_depth);
EXPECT_EQ(inputs.nlopt_seed(), _nlopt_seed); EXPECT_EQ(inputs.nlopt_seed(), _nlopt_seed);
EXPECT_EQ(inputs.fix_intercept(), _fix_intercept);
EXPECT_EQ(inputs.global_param_opt(), _global_param_opt); EXPECT_EQ(inputs.global_param_opt(), _global_param_opt);
EXPECT_EQ(inputs.reparam_residual(), _reparam_residual); EXPECT_EQ(inputs.reparam_residual(), _reparam_residual);
#endif
} }
} }
...@@ -10,12 +10,7 @@ ...@@ -10,12 +10,7 @@
"task_key": "task", "task_key": "task",
"leave_out_inds": [3], "leave_out_inds": [3],
"opset": ["sq", "cb"], "opset": ["sq", "cb"],
"param_opset": ["log"],
"fix_intercept": false, "fix_intercept": false,
"min_abs_feat_val": 1e-5, "min_abs_feat_val": 1e-5,
"max_abs_feat_val": 1e8, "max_abs_feat_val": 1e8
"max_param_depth": 1,
"nlopt_seed": 10,
"global_param_opt": true,
"reparam_residual": true
} }
{
"desc_dim": 2,
"n_sis_select": 1,
"max_rung": 1,
"n_residual": 1,
"n_models_store": 1,
"n_rung_store": 1,
"data_file": "googletest/inputs/data.csv",
"property_key": "property",
"task_key": "task",
"leave_out_inds": [3],
"opset": ["sq", "cb"],
"param_opset": ["log"],
"fix_intercept": false,
"min_abs_feat_val": 1e-5,
"max_abs_feat_val": 1e8,
"max_param_depth": 1,
"nlopt_seed": 10,
"global_param_opt": true,
"reparam_residual": true
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment