From 4ae790499eff4ceea93c65694c2012e2bde1deb0 Mon Sep 17 00:00:00 2001
From: Thomas <purcell@fhi-berlin.mpg.de>
Date: Fri, 27 Aug 2021 13:46:07 +0200
Subject: [PATCH] Add input marker to set data file relative to sisso.json file
 name

used for make test
---
 src/inputs/InputParser.cpp                  | 18 ++++++++++++++++--
 src/inputs/InputParser.hpp                  |  1 +
 tests/exec_test/classification/sisso.json   |  1 +
 tests/exec_test/default/sisso.json          |  1 +
 tests/exec_test/gen_proj/sisso.json         |  1 +
 tests/exec_test/log_reg/sisso.json          |  1 +
 tests/exec_test/max_corr/sisso.json         |  1 +
 tests/exec_test/no_test_data/check_model.py |  9 +++++++++
 tests/exec_test/no_test_data/sisso.json     | 14 ++++++++++++++
 tests/exec_test/param/sisso.json            |  1 +
 tests/exec_test/reparam/sisso.json          |  1 +
 11 files changed, 47 insertions(+), 2 deletions(-)
 create mode 100644 tests/exec_test/no_test_data/check_model.py
 create mode 100644 tests/exec_test/no_test_data/sisso.json

diff --git a/src/inputs/InputParser.cpp b/src/inputs/InputParser.cpp
index 0363fc2e..c791f1e4 100644
--- a/src/inputs/InputParser.cpp
+++ b/src/inputs/InputParser.cpp
@@ -45,6 +45,7 @@ InputParser::InputParser() :
     _n_models_store(1),
     _max_param_depth(-1),
     _nlopt_seed(42),
+    _data_file_relative_to_json(false),
     _fix_intercept(false),
     _global_param_opt(false),
     _reparam_residual(false)
@@ -79,10 +80,21 @@ InputParser::InputParser(pt::ptree ip, std::string fn, std::shared_ptr<MPI_Inter
     _n_models_store(ip.get<int>("n_models_store", _n_residual)),
     _max_param_depth(ip.get<int>("max_feat_param_depth", _max_rung)),
     _nlopt_seed(ip.get<int>("nlopt_seed", 42)),
+    _data_file_relative_to_json(ip.get<bool>("data_file_relatice_to_json", false)),
     _fix_intercept(ip.get<bool>("fix_intercept", false)),
     _global_param_opt(ip.get<bool>("global_param_opt", false)),
     _reparam_residual(ip.get<bool>("reparam_residual", false))
 {
+    if(_data_file_relative_to_json)
+    {
+        if((_data_file[0] == '/') || (_data_file[0] == '\\'))
+        {
+            throw std::logic_error("The data file is an absolute path, but _data_file_relative_to_json is true.");
+        }
+        std::vector<std::string> filepath = str_utils::split_string_trim(fn, "/");
+        _data_file = str_utils::join("/", filepath.data(), filepath.size() - 1) + "/" + _data_file;
+    }
+
     // Check if param ops are passed without being build with parameterized features
     #ifndef PARAMETERIZE
     if(_allowed_param_ops.size() > 0)
@@ -806,7 +818,8 @@ void InputParser::set_phi_0(std::vector<FeatureNode> phi_0)
 void strip_comments(std::string& filename)
 {
     //Open input and output file
-    std::string newfn = "stripped_" + filename;
+    std::vector<std::string> filepath = str_utils::split_string_trim(filename, "/");
+    std::string newfn = str_utils::join("/", filepath.data(), filepath.size() - 1) + "/stripped_" + filepath.back();
     std::fstream inputfile;
     inputfile.open(filename);
     std::ofstream inputcopy;
@@ -846,7 +859,8 @@ pt::ptree get_prop_tree(std::string fn, std::shared_ptr<MPI_Interface> mpi_comm)
     }
     else
     {
-        fn = "stripped_" + fn;
+        std::vector<std::string> filepath = str_utils::split_string_trim(fn, "/");
+        fn = str_utils::join("/", filepath.data(), filepath.size() - 1) + "/stripped_" + filepath.back();
     }
 
     mpi_comm->barrier();
diff --git a/src/inputs/InputParser.hpp b/src/inputs/InputParser.hpp
index 3ee2c351..7ddd2322 100644
--- a/src/inputs/InputParser.hpp
+++ b/src/inputs/InputParser.hpp
@@ -104,6 +104,7 @@ private:
     int _nlopt_seed; //!< The seed used for the nlOpt library
 
     bool _fix_intercept; //!< If true the bias term is fixed at 0
+    bool _data_file_relative_to_json; //!< If true then the data filepath is relative to the filename path
     bool _global_param_opt; //!< True if global optimization is requested for non-linear optimization of parameters (Can break reproducibility)
     bool _reparam_residual; //!< If True then reparameterize features using the residuals of each model
 
diff --git a/tests/exec_test/classification/sisso.json b/tests/exec_test/classification/sisso.json
index 8a758b7f..196884b5 100644
--- a/tests/exec_test/classification/sisso.json
+++ b/tests/exec_test/classification/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 1,
     "n_residual": 1,
     "data_file": "data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "prop",
     "leave_out_frac": 0.2,
     "n_models_store": 1,
diff --git a/tests/exec_test/default/sisso.json b/tests/exec_test/default/sisso.json
index 48cf9c98..384e552e 100644
--- a/tests/exec_test/default/sisso.json
+++ b/tests/exec_test/default/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 2,
     "n_residual": 1,
     "data_file": "../data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "task_key": "Task",
     "leave_out_frac": 0.05,
diff --git a/tests/exec_test/gen_proj/sisso.json b/tests/exec_test/gen_proj/sisso.json
index 69732ffd..f6159318 100644
--- a/tests/exec_test/gen_proj/sisso.json
+++ b/tests/exec_test/gen_proj/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 2,
     "n_residual": 1,
     "data_file": "../data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "task_key": "Task",
     "leave_out_frac": 0.05,
diff --git a/tests/exec_test/log_reg/sisso.json b/tests/exec_test/log_reg/sisso.json
index dc94add2..d8265a70 100644
--- a/tests/exec_test/log_reg/sisso.json
+++ b/tests/exec_test/log_reg/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 1,
     "n_residual": 1,
     "data_file": "data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "leave_out_frac": 0.05,
     "n_models_store": 1,
diff --git a/tests/exec_test/max_corr/sisso.json b/tests/exec_test/max_corr/sisso.json
index 246dd287..444e27b9 100644
--- a/tests/exec_test/max_corr/sisso.json
+++ b/tests/exec_test/max_corr/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 2,
     "n_residual": 1,
     "data_file": "../data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "task_key": "Task",
     "leave_out_frac": 0.05,
diff --git a/tests/exec_test/no_test_data/check_model.py b/tests/exec_test/no_test_data/check_model.py
new file mode 100644
index 00000000..50da73db
--- /dev/null
+++ b/tests/exec_test/no_test_data/check_model.py
@@ -0,0 +1,9 @@
+from sissopp import ModelRegressor
+from pathlib import Path
+
+import numpy as np
+
+model = ModelRegressor(
+    str("models/train_dim_2_model_0.dat")
+)
+assert model.rmse < 1e-4
diff --git a/tests/exec_test/no_test_data/sisso.json b/tests/exec_test/no_test_data/sisso.json
new file mode 100644
index 00000000..7dcb3488
--- /dev/null
+++ b/tests/exec_test/no_test_data/sisso.json
@@ -0,0 +1,14 @@
+{
+    "desc_dim": 2,
+    "n_sis_select": 1,
+    "max_rung": 2,
+    "n_residual": 1,
+    "data_file": "../data.csv",
+    "data_file_relatice_to_json": true,
+    "property_key": "Prop",
+    "task_key": "Task",
+    "leave_out_frac": 0.00,
+    "n_models_store": 1,
+    "opset": ["add", "sub", "abs_diff", "mult", "div", "inv", "abs", "exp", "log", "sin", "cos", "sq", "cb", "six_pow", "sqrt", "cbrt", "neg_exp"],
+    "fix_intercept": false
+}
diff --git a/tests/exec_test/param/sisso.json b/tests/exec_test/param/sisso.json
index fe8b5275..4d656340 100644
--- a/tests/exec_test/param/sisso.json
+++ b/tests/exec_test/param/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 2,
     "n_residual": 1,
     "data_file": "data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "task_key": "Task",
     "leave_out_frac": 0.05,
diff --git a/tests/exec_test/reparam/sisso.json b/tests/exec_test/reparam/sisso.json
index 73ba070e..4a19c4e7 100644
--- a/tests/exec_test/reparam/sisso.json
+++ b/tests/exec_test/reparam/sisso.json
@@ -4,6 +4,7 @@
     "max_rung": 1,
     "n_residual": 1,
     "data_file": "data.csv",
+    "data_file_relatice_to_json": true,
     "property_key": "Prop",
     "leave_out_frac": 0.05,
     "n_rung_generate": 1,
-- 
GitLab