diff --git a/src/descriptor_identifier/model/Model.cpp b/src/descriptor_identifier/model/Model.cpp index cce0459e260e6dd452c6aa1ca37218f7b41d30de..697f96f95cfb33c5ee6661e85a7664e8eb3c372b 100644 --- a/src/descriptor_identifier/model/Model.cpp +++ b/src/descriptor_identifier/model/Model.cpp @@ -86,10 +86,6 @@ double Model::eval(std::map<std::string, double> x_in_dct) const { throw std::logic_error("The value of " + in_expr + " is not in x_in_dct."); } - else if(x_in_dct.count(in_expr) > 1) - { - throw std::logic_error("Multiple values of " + in_expr + " defined in x_in_dct."); - } x_in.push_back(x_in_dct[in_expr]); } @@ -132,10 +128,6 @@ std::vector<double> Model::eval(std::map<std::string, std::vector<double>> x_in_ { throw std::logic_error("The value of " + in_expr + " is not in x_in_dct."); } - else if(x_in_dct.count(in_expr) > 1) - { - throw std::logic_error("Multiple values of " + in_expr + " defined in x_in_dct."); - } x_in.push_back(x_in_dct[in_expr]); } @@ -276,6 +268,8 @@ void Model::write_matlab_fxn(std::string fxn_filename) } boost::filesystem::path p(fxn_filename.c_str()); + std::string fxn_name = p.filename().string(); + fxn_name = fxn_name.substr(0, fxn_name.size() - 2); boost::filesystem::path parent = p.remove_filename(); if(parent.string().size() > 0) { @@ -296,7 +290,7 @@ void Model::write_matlab_fxn(std::string fxn_filename) std::transform(leaves.begin(), leaves.end(), leaves.begin(), [](std::string s){return str_utils::matlabify(s);}); // Write the header of the function - out_file_stream << "function P = " << fxn_filename.substr(0, fxn_filename.size() - 2) << "(X)\n"; + out_file_stream << "function P = " << fxn_name << "(X)\n"; out_file_stream << "% Returns the value of " << _prop_label << " = " << toString() << "\n%\n"; out_file_stream << "% X = [\n"; for(auto & leaf : leaves) @@ -373,6 +367,8 @@ void Model::populate_model(const std::string train_filename, const std::string t std::string test_error_line; std::getline(train_file_stream, prop_desc_line); int n_line = 5; + + // Legacy Code so previous model files can be read in if(!is_error_line(prop_desc_line)) { split_line = str_utils::split_string_trim(prop_desc_line); diff --git a/src/descriptor_identifier/model/ModelClassifier.cpp b/src/descriptor_identifier/model/ModelClassifier.cpp index 6876a4c2bf4189e25fa5d30bc2ade1ea7a3ecfc4..ec9fd8251f30f16204bb8adda7fbc9147728f50e 100644 --- a/src/descriptor_identifier/model/ModelClassifier.cpp +++ b/src/descriptor_identifier/model/ModelClassifier.cpp @@ -59,7 +59,9 @@ ModelClassifier::ModelClassifier( _test_n_svm_misclassified = std::accumulate(test_misclassified.begin(), test_misclassified.end(), 0); } -ModelClassifier::ModelClassifier(const std::string train_file) +ModelClassifier::ModelClassifier(const std::string train_file) : + _train_n_convex_overlap(0), + _test_n_convex_overlap(0) { populate_model(train_file); _n_class = _loss->n_class(); @@ -79,7 +81,9 @@ ModelClassifier::ModelClassifier(const std::string train_file) ); } } -ModelClassifier::ModelClassifier(const std::string train_file, std::string test_file) +ModelClassifier::ModelClassifier(const std::string train_file, std::string test_file) : + _train_n_convex_overlap(0), + _test_n_convex_overlap(0) { populate_model(train_file, test_file); _n_class = _loss->n_class(); diff --git a/src/descriptor_identifier/model/ModelClassifier.hpp b/src/descriptor_identifier/model/ModelClassifier.hpp index 6126910ccda7c56d42766ae925f4eb8f86d56442..cf6163af5f9a7071f259e799b54c9b97295338fc 100644 --- a/src/descriptor_identifier/model/ModelClassifier.hpp +++ b/src/descriptor_identifier/model/ModelClassifier.hpp @@ -43,10 +43,9 @@ class ModelClassifier : public Model int _train_n_svm_misclassified; //!< The number of points misclassified by SVM (training set) int _test_n_svm_misclassified; //!< The number of points misclassified by SVM (test set) -protected: - using Model::eval; public: + using Model::eval; /** * @brief Construct a ModelClassifier using a loss function and a set of features diff --git a/src/descriptor_identifier/model/ModelLogRegressor.cpp b/src/descriptor_identifier/model/ModelLogRegressor.cpp index a7519adca0b6d4e67f16854e9cddab98641f618d..caf1f0d3603fbf5e556f948eff2964474b9ab34f 100644 --- a/src/descriptor_identifier/model/ModelLogRegressor.cpp +++ b/src/descriptor_identifier/model/ModelLogRegressor.cpp @@ -126,7 +126,7 @@ std::string ModelLogRegressor::toLatexString() const std::stringstream model_rep; if(_fix_intercept) { - model_rep << "$\\left(" << _feats[0]->get_latex_expr() << "\\right)^{a_0}" << std::endl; + model_rep << "$\\left(" << _feats[0]->get_latex_expr() << "\\right)^{a_0}"; for(int ff = 1; ff < _feats.size(); ++ff) { model_rep << "\\left(" << _feats[ff]->get_latex_expr() << "\\right)^{a_" << ff << "}"; diff --git a/src/descriptor_identifier/solver/SISSOSolver.hpp b/src/descriptor_identifier/solver/SISSOSolver.hpp index 44955e67619b708cf37a39f9de17812186aea850..7c4bddd09e3a407d7c9b3035ba15b06113136168 100644 --- a/src/descriptor_identifier/solver/SISSOSolver.hpp +++ b/src/descriptor_identifier/solver/SISSOSolver.hpp @@ -128,6 +128,11 @@ public: */ inline int n_models_store() const {return _n_models_store;} + /** + * @brief If true the bias term is fixed at 0 + */ + inline bool fix_intercept() const {return _fix_intercept;} + // Python interface functions #ifdef PY_BINDINGS diff --git a/tests/googletest/descriptor_identification/model/test_model_classifier.cc b/tests/googletest/descriptor_identification/model/test_model_classifier.cc index ac1583394625f7332832b021e44ac54dcb94b242..3148a09cbe6ede02e85e61a0f75f652341731956 100644 --- a/tests/googletest/descriptor_identification/model/test_model_classifier.cc +++ b/tests/googletest/descriptor_identification/model/test_model_classifier.cc @@ -143,4 +143,35 @@ namespace boost::filesystem::remove("train_class_mods.dat"); boost::filesystem::remove("test_class_mods.dat"); } + + TEST_F(ModelClassifierTests, EvalTest) + { + ModelClassifier model( + "Property", + Unit("m"), + _loss, + _features, + _leave_out_inds, + _sample_ids_train, + _sample_ids_test, + task_names + ); + + model.set_task_eval(0); + std::vector<double> pt = {2.0, 2.0}; + EXPECT_THROW(model.eval(pt), std::logic_error); + + std::map<std::string, double> pt_dct; + pt_dct["A"] = 1.0; + pt_dct["B"] = 1.0; + EXPECT_THROW(model.eval(pt_dct), std::logic_error); + + std::vector<std::vector<double>> pts = {{1.0}, {1.0}}; + EXPECT_THROW(model.eval(pts), std::logic_error); + + std::map<std::string, std::vector<double>> pts_dct; + pts_dct["A"] = {1.0}; + pts_dct["B"] = {1.0}; + EXPECT_THROW(model.eval(pts_dct), std::logic_error); + } } diff --git a/tests/googletest/descriptor_identification/model/test_model_log_regressor.cc b/tests/googletest/descriptor_identification/model/test_model_log_regressor.cc index 2a46674652cad5cd05806a574e9edd8dad251918..054e89a6a51a1d2bc309e5c2f81b42d3d4621b8a 100644 --- a/tests/googletest/descriptor_identification/model/test_model_log_regressor.cc +++ b/tests/googletest/descriptor_identification/model/test_model_log_regressor.cc @@ -127,6 +127,8 @@ namespace EXPECT_LT(std::abs(model.coefs()[0][1] + 2.1), 1e-10); EXPECT_LT(std::abs(model.coefs()[0][2] - std::log(0.001)), 1e-10); + EXPECT_STREQ(model.toLatexString().c_str(), "$\\exp\\left(c_0\\right)\\left(A\\right)^{a_0}\\left(B\\right)^{a_1}$"); + model.to_file("train_false_log_reg.dat", true); model.to_file("test_false_log_reg.dat", false); } @@ -248,6 +250,8 @@ namespace model.to_file("train_true_log_reg.dat", true); model.to_file("test_true_log_reg.dat", false); + + EXPECT_STREQ(model.toLatexString().c_str(), "$\\left(A\\right)^{a_0}\\left(B\\right)^{a_1}$"); } TEST_F(ModelLogRegssorTests, FixInterceptTrueFileTest) @@ -340,5 +344,23 @@ namespace pts_dct["B"] = {1.0}; val = model.eval(pts_dct)[0]; EXPECT_LT(val - 0.00025, 1e-10); + + pt.push_back(1.0); + EXPECT_THROW(model.eval(pt), std::logic_error); + + pts.push_back({1.0}); + EXPECT_THROW(model.eval(pts), std::logic_error); + + pts.pop_back(); + pts.back().push_back(1.0); + EXPECT_THROW(model.eval(pts), std::logic_error); + + pts_dct["A"] = {1.0, 1.0}; + EXPECT_THROW(model.eval(pts_dct), std::logic_error); + + pt_dct.erase("A"); + pts_dct.erase("A"); + EXPECT_THROW(model.eval(pt_dct), std::logic_error); + EXPECT_THROW(model.eval(pts_dct), std::logic_error); } } diff --git a/tests/googletest/descriptor_identification/model/test_model_regressor.cc b/tests/googletest/descriptor_identification/model/test_model_regressor.cc index c270f02d6ad7ee5630843064bb0fc872d9c5a557..62ecc14ad655e4252e51c21b20158deb76d4ea1a 100644 --- a/tests/googletest/descriptor_identification/model/test_model_regressor.cc +++ b/tests/googletest/descriptor_identification/model/test_model_regressor.cc @@ -133,6 +133,7 @@ namespace EXPECT_LT(std::abs(model.coefs()[1][1] + 0.4), 1e-10); EXPECT_LT(std::abs(model.coefs()[1][2] + 6.5), 1e-10); + EXPECT_STREQ(model.toLatexString().c_str(), "$c_0 + a_0A + a_1B$"); model.to_file("train_false.dat", true); model.to_file("test_false.dat", false); } @@ -259,6 +260,8 @@ namespace EXPECT_LT(std::abs(model.coefs()[1][0] - 1.25), 1e-10); EXPECT_LT(std::abs(model.coefs()[1][1] + 0.4), 1e-10); + EXPECT_STREQ(model.toLatexString().c_str(), "$a_0A + a_1B$"); + model.to_file("train_true.dat", true); model.to_file("test_true.dat", false); } @@ -359,5 +362,23 @@ namespace EXPECT_LT(model.eval(pt_dct) + 5.65, 1e-10); EXPECT_LT(model.eval(pts)[0] + 5.65, 1e-10); EXPECT_LT(model.eval(pts_dct)[0] + 5.65, 1e-10); + + pt.push_back(1.0); + EXPECT_THROW(model.eval(pt), std::logic_error); + + pts.push_back({1.0}); + EXPECT_THROW(model.eval(pts), std::logic_error); + + pts.pop_back(); + pts.back().push_back(1.0); + EXPECT_THROW(model.eval(pts), std::logic_error); + + pts_dct["A"] = {1.0, 1.0}; + EXPECT_THROW(model.eval(pts_dct), std::logic_error); + + pt_dct.erase("A"); + pts_dct.erase("A"); + EXPECT_THROW(model.eval(pt_dct), std::logic_error); + EXPECT_THROW(model.eval(pts_dct), std::logic_error); } } diff --git a/tests/googletest/descriptor_identification/solver/test_sisso_classifier.cc b/tests/googletest/descriptor_identification/solver/test_sisso_classifier.cc index 1034836027c1629b6bd063584add38f3d4abba64..0b0f500183e8e5f23ec8cecd4587d39342d7e8bb 100644 --- a/tests/googletest/descriptor_identification/solver/test_sisso_classifier.cc +++ b/tests/googletest/descriptor_identification/solver/test_sisso_classifier.cc @@ -201,4 +201,38 @@ namespace boost::filesystem::remove_all("feature_space/"); boost::filesystem::remove_all("models/"); } + + TEST_F(SISSOClassifierTests, FixInterceptTrueTest) + { + std::shared_ptr<FeatureSpace> feat_space = std::make_shared<FeatureSpace>(inputs); + inputs.set_fix_intercept(true); + SISSOClassifier sisso(inputs, feat_space); + EXPECT_FALSE(sisso.fix_intercept()); + + std::vector<double> prop_comp(80, 0.0); + std::transform(inputs.prop_train().begin(), inputs.prop_train().end(), sisso.prop_train().begin(), prop_comp.begin(), [](double p1, double p2){return std::abs(p1 - p2);}); + EXPECT_FALSE(std::any_of(prop_comp.begin(), prop_comp.end(), [](double p){return p > 1e-10;})); + + std::transform(inputs.prop_test().begin(), inputs.prop_test().begin() + 10, sisso.prop_test().begin(), prop_comp.begin(), [](double p1, double p2){return std::abs(p1 - p2);}); + EXPECT_FALSE(std::any_of(prop_comp.begin(), prop_comp.begin() + 10, [](double p){return p > 1e-10;})); + + EXPECT_EQ(sisso.n_samp(), 80); + EXPECT_EQ(sisso.n_dim(), 2); + EXPECT_EQ(sisso.n_residual(), 2); + EXPECT_EQ(sisso.n_models_store(), 3); + + sisso.fit(); + + EXPECT_EQ(sisso.models().size(), 2); + EXPECT_EQ(sisso.models()[0].size(), 3); + + EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0); + EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_test(), 0); + + EXPECT_EQ(sisso.models().back()[0].n_svm_misclassified_train(), 0); + EXPECT_EQ(sisso.models().back()[0].n_svm_misclassified_test(), 0); + + boost::filesystem::remove_all("feature_space/"); + boost::filesystem::remove_all("models/"); + } } diff --git a/tests/pytest/test_descriptor_identifier/matlab_functions/model_log_regressor.m b/tests/pytest/test_descriptor_identifier/matlab_functions/model_log_regressor.m index 90ff4941e2b68e2032ddd73f4f7dbeaa117aba3b..28125195cbb3a3c5213b677e9de9ca7d17d5534b 100644 --- a/tests/pytest/test_descriptor_identifier/matlab_functions/model_log_regressor.m +++ b/tests/pytest/test_descriptor_identifier/matlab_functions/model_log_regressor.m @@ -1,5 +1,5 @@ function P = model_log_regressor(X) -% Returns the value of Prop = exp(c0) * ((B + A))^a0 * ((|D - B|))^a1 +% Returns the value of Prop = ((B + A))^a0 * ((|D - B|))^a1 % % X = [ % B, @@ -17,7 +17,7 @@ D = reshape(X(:, 3), 1, []); f0 = (B + A); f1 = abs(D - B); -c0 = 2.1945699276e-13; +c0 = 0.0; a0 = 1.2000000000e+00; a1 = -1.9500000000e+00; diff --git a/tests/pytest/test_descriptor_identifier/model_files/test_classifier_fail_overlap.dat b/tests/pytest/test_descriptor_identifier/model_files/test_classifier_fail_overlap.dat new file mode 100644 index 0000000000000000000000000000000000000000..cc68ee55cc7362d377a853c5520b0f9c6eb775c6 --- /dev/null +++ b/tests/pytest/test_descriptor_identifier/model_files/test_classifier_fail_overlap.dat @@ -0,0 +1,35 @@ +# [(feat_9 - feat_8), (feat_1 * feat_0)] +# Property Label: $Class$; Unit of the Property: Unitless +# # Samples in Convex Hull Overlap Region: 5;# Samples SVM Misclassified: 0 +# Decision Boundaries +# Task w0 w1 b +# all_0, 1.326205649731981e+00, -1.744239999671528e+00, 9.075950727790907e-01, +# Feature Rung, Units, and Expressions +# 0; 1; Unitless; 9|8|sub; (feat_9 - feat_8); $\left(feat_{9} - feat_{8}\right)$; (feat_9 - feat_8); feat_9,feat_8 +# 1; 1; Unitless; 1|0|mult; (feat_1 * feat_0); $\left(feat_{1} feat_{0}\right)$; (feat_1 .* feat_0); feat_1,feat_0 +# Number of Samples Per Task +# Task, n_mats_test +# all, 20 +# Test Indexes: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ] + +# Sample ID , Property Value , Property Value (EST) , Feature 0 Value , Feature 1 Value +0 , 0.000000000000000e+00, 0.000000000000000e+00, -2.031012155743963e-01, 2.746937325370922e+00 +1 , 0.000000000000000e+00, 0.000000000000000e+00, 3.178950007570245e-01, 3.060176523352919e+00 +2 , 0.000000000000000e+00, 0.000000000000000e+00, 1.350899479575136e+00, 1.904914737747669e+00 +3 , 0.000000000000000e+00, 0.000000000000000e+00, 3.112816979685040e-01, 2.597514970348419e+00 +4 , 0.000000000000000e+00, 0.000000000000000e+00, 3.256274649800963e-01, 3.823832277859604e+00 +5 , 0.000000000000000e+00, 0.000000000000000e+00, 7.291401227120657e-01, 2.789443909864211e+00 +6 , 0.000000000000000e+00, 0.000000000000000e+00, 3.051059498409199e-01, 2.087853428517832e+00 +7 , 0.000000000000000e+00, 0.000000000000000e+00, 5.345306546910435e-01, 2.507012794375703e+00 +8 , 0.000000000000000e+00, 0.000000000000000e+00, -5.273941950401386e-01, 1.812203393718137e+00 +9 , 0.000000000000000e+00, 0.000000000000000e+00, -1.780367164555883e-01, 3.143604947474592e+00 +10 , 1.000000000000000e+00, 1.000000000000000e+00, 1.274772829175870e+00, -2.054835335229399e+00 +11 , 1.000000000000000e+00, 1.000000000000000e+00, -1.099107097822589e+00, -2.210701955514223e+00 +12 , 1.000000000000000e+00, 1.000000000000000e+00, -2.522737334308300e-02, -2.127724030671242e+00 +13 , 1.000000000000000e+00, 1.000000000000000e+00, -8.048228984834345e-01, -3.158579181125339e+00 +14 , 1.000000000000000e+00, 1.000000000000000e+00, -1.875592975314526e-01, -1.974183109498213e+00 +15 , 1.000000000000000e+00, 1.000000000000000e+00, 6.149517560499549e-01, -1.721250821422664e+00 +16 , 1.000000000000000e+00, 1.000000000000000e+00, 1.679421386195452e+00, -2.639919246265093e+00 +17 , 1.000000000000000e+00, 1.000000000000000e+00, -3.729001722113563e-01, -2.014587145399039e+00 +18 , 1.000000000000000e+00, 1.000000000000000e+00, -1.241140800579893e+00, -1.410625471724265e+00 +19 , 1.000000000000000e+00, 1.000000000000000e+00, 3.358150821235752e-01, -1.769187643167631e+00 diff --git a/tests/pytest/test_descriptor_identifier/model_files/test_log_regressor.dat b/tests/pytest/test_descriptor_identifier/model_files/test_log_regressor.dat index 0594278ce2e9ef216951f103b3b671d521787a7b..a6810bce51929ed79dbc868e9889fdbf0e96d577 100644 --- a/tests/pytest/test_descriptor_identifier/model_files/test_log_regressor.dat +++ b/tests/pytest/test_descriptor_identifier/model_files/test_log_regressor.dat @@ -1,9 +1,9 @@ -# exp(c0) * ((B + A))^a0 * ((|D - B|))^a1 +# ((B + A))^a0 * ((|D - B|))^a1 # Property Label: $Prop$; Unit of the Property: Unitless # RMSE: 1.61410365875894e-15; Max AE: 3.10862446895044e-15 # Coefficients -# Task a0 a1 c0 -# all , 1.199999999999988e+00, -1.950000000000029e+00, 2.194569927587456e-13, +# Task a0 a1 +# all , 1.199999999999988e+00, -1.950000000000029e+00, # Feature Rung, Units, and Expressions # 0; 1; Unitless; 1|0|add; (B + A); $\left(B + A\right)$; (B + A); B,A # 1; 1; Unitless; 3|1|abd; (|D - B|); $\left(\left|D - B\right|\right)$; abs(D - B); D,B diff --git a/tests/pytest/test_descriptor_identifier/model_files/train_classifier_fail_overlap.dat b/tests/pytest/test_descriptor_identifier/model_files/train_classifier_fail_overlap.dat new file mode 100644 index 0000000000000000000000000000000000000000..c98b8489ab242abbdc608dc7a7f39790276a78dc --- /dev/null +++ b/tests/pytest/test_descriptor_identifier/model_files/train_classifier_fail_overlap.dat @@ -0,0 +1,94 @@ +# [(feat_9 - feat_8), (feat_1 * feat_0)] +# Property Label: $Class$; Unit of the Property: Unitless +# # Samples in Convex Hull Overlap Region: 5;# Samples SVM Misclassified: 0 +# Decision Boundaries +# Task w0 w1 b +# all_0, 1.326205649731981e+00, -1.744239999671528e+00, 9.075950727790907e-01, +# Feature Rung, Units, and Expressions +# 0; 1; Unitless; 9|8|sub; (feat_9 - feat_8); $\left(feat_{9} - feat_{8}\right)$; (feat_9 - feat_8); feat_9,feat_8 +# 1; 1; Unitless; 1|0|mult; (feat_1 * feat_0); $\left(feat_{1} feat_{0}\right)$; (feat_1 .* feat_0); feat_1,feat_0 +# Number of Samples Per Task +# Task, n_mats_train +# all , 80 + +# Sample ID , Property Value , Property Value (EST) , Feature 0 Value , Feature 1 Value +20 , 0.000000000000000e+00, 0.000000000000000e+00, -1.438535620957356e+00, -1.000000000000000e-04 +21 , 0.000000000000000e+00, 0.000000000000000e+00, 1.190224286778585e-01, 1.551266013053755e+00 +22 , 0.000000000000000e+00, 0.000000000000000e+00, -1.425077659929501e-01, 2.130683687424777e+00 +23 , 0.000000000000000e+00, 0.000000000000000e+00, -4.540168315339039e-01, 2.634756018203185e+00 +24 , 0.000000000000000e+00, 0.000000000000000e+00, 6.192845577547714e-01, 1.928013807462464e+00 +25 , 0.000000000000000e+00, 0.000000000000000e+00, -8.781439552075476e-01, 2.167912710491058e+00 +26 , 0.000000000000000e+00, 0.000000000000000e+00, -3.535569323953591e-01, 1.528444956153448e+00 +27 , 0.000000000000000e+00, 0.000000000000000e+00, 1.339601116269036e-01, 1.802459125102224e+00 +28 , 0.000000000000000e+00, 0.000000000000000e+00, -1.136023513610327e+00, 3.215479730179889e+00 +29 , 0.000000000000000e+00, 0.000000000000000e+00, 6.038078099519975e-01, 2.573190329369634e+00 +30 , 0.000000000000000e+00, 0.000000000000000e+00, -1.215056999365121e-01, 1.595270557148176e+00 +31 , 0.000000000000000e+00, 0.000000000000000e+00, -5.518151942675462e-01, 1.730168908796035e+00 +32 , 0.000000000000000e+00, 0.000000000000000e+00, 1.650765229640842e+00, 3.203700967878904e+00 +33 , 0.000000000000000e+00, 0.000000000000000e+00, -1.184077335611730e+00, 1.517408938916172e+00 +34 , 0.000000000000000e+00, 0.000000000000000e+00, -2.221541956972297e-01, 1.751520526988180e+00 +35 , 0.000000000000000e+00, 0.000000000000000e+00, -4.681000124924655e-01, 1.824502458636519e+00 +36 , 0.000000000000000e+00, 0.000000000000000e+00, -1.050081904577687e-01, 1.724802253834064e+00 +37 , 0.000000000000000e+00, 0.000000000000000e+00, 1.339981411155463e+00, 2.558207468331875e+00 +38 , 0.000000000000000e+00, 0.000000000000000e+00, -1.684823671566578e-01, 1.506294601099636e+00 +39 , 0.000000000000000e+00, 0.000000000000000e+00, 7.714455467324950e-01, 2.242769710603608e+00 +40 , 0.000000000000000e+00, 0.000000000000000e+00, -1.622660863089168e+00, -1.000000000000000e-04 +41 , 0.000000000000000e+00, 0.000000000000000e+00, -3.851141611133064e-01, 1.621724598445333e+00 +42 , 0.000000000000000e+00, 0.000000000000000e+00, -7.796131832604434e-01, 3.412199890833602e+00 +43 , 0.000000000000000e+00, 0.000000000000000e+00, 6.330392653717503e-01, 2.644525290379403e+00 +44 , 0.000000000000000e+00, 0.000000000000000e+00, 4.472597809306964e-01, 1.639977210905994e+00 +45 , 0.000000000000000e+00, 0.000000000000000e+00, 5.619997969970609e-01, 2.117832540122095e+00 +46 , 0.000000000000000e+00, 0.000000000000000e+00, 2.693335908708820e-01, 3.719200588905764e+00 +47 , 0.000000000000000e+00, 0.000000000000000e+00, -6.945169947212362e-01, 2.658310913357233e+00 +48 , 0.000000000000000e+00, 0.000000000000000e+00, -2.608343436805389e-01, 2.389278127799646e+00 +49 , 0.000000000000000e+00, 0.000000000000000e+00, -5.883177866461617e-01, 1.194385279781109e+00 +50 , 0.000000000000000e+00, 0.000000000000000e+00, 3.016305034685407e-01, 2.163287243369974e+00 +51 , 0.000000000000000e+00, 0.000000000000000e+00, -8.429615971293545e-01, 3.143453483796918e+00 +52 , 0.000000000000000e+00, 0.000000000000000e+00, -2.305301655628482e-01, 2.373605928069240e+00 +53 , 0.000000000000000e+00, 0.000000000000000e+00, 8.169785601229205e-01, 3.393041023148420e+00 +54 , 0.000000000000000e+00, 0.000000000000000e+00, 5.880968210966282e-01, 1.540775049989281e+00 +55 , 0.000000000000000e+00, 0.000000000000000e+00, -8.439557195782834e-01, 2.354515308140759e+00 +56 , 0.000000000000000e+00, 0.000000000000000e+00, 1.145691901781707e-01, 3.057598248128036e+00 +57 , 0.000000000000000e+00, 0.000000000000000e+00, 5.052378789612302e-01, 3.681321981867383e+00 +58 , 0.000000000000000e+00, 0.000000000000000e+00, 6.830515610497974e-01, 2.677195784075541e+00 +59 , 0.000000000000000e+00, 0.000000000000000e+00, -3.962323210078385e-01, 2.494759927195949e+00 +60 , 1.000000000000000e+00, 1.000000000000000e+00, 6.978964070103477e-02, 1.000000000000000e-04 +61 , 1.000000000000000e+00, 1.000000000000000e+00, -1.017083127934503e+00, -1.718221667867104e+00 +62 , 1.000000000000000e+00, 1.000000000000000e+00, -5.038361140934988e-02, -3.023687952494995e+00 +63 , 1.000000000000000e+00, 1.000000000000000e+00, 2.981066824725631e-02, -2.580950415579647e+00 +64 , 1.000000000000000e+00, 1.000000000000000e+00, 1.173640969341423e+00, -2.015913518051938e+00 +65 , 1.000000000000000e+00, 1.000000000000000e+00, 8.711405011252915e-02, -3.331488038371359e+00 +66 , 1.000000000000000e+00, 1.000000000000000e+00, 1.309781594456224e+00, -2.340258337136148e+00 +67 , 1.000000000000000e+00, 1.000000000000000e+00, -2.024028100438937e-01, -1.817820634181115e+00 +68 , 1.000000000000000e+00, 1.000000000000000e+00, -2.684686877159819e-01, -1.754047733957138e+00 +69 , 1.000000000000000e+00, 1.000000000000000e+00, 1.446320111150274e-01, -2.385204762866371e+00 +70 , 1.000000000000000e+00, 1.000000000000000e+00, -2.832821671606189e-01, -2.001289065001360e+00 +71 , 1.000000000000000e+00, 1.000000000000000e+00, -3.128846468236810e-01, -1.884355389893358e+00 +72 , 1.000000000000000e+00, 1.000000000000000e+00, 1.383377419667691e-01, -2.044929395284636e+00 +73 , 1.000000000000000e+00, 1.000000000000000e+00, -8.811671096262539e-01, -1.442201355797836e+00 +74 , 1.000000000000000e+00, 1.000000000000000e+00, 6.544208451153577e-02, -1.908068625698732e+00 +75 , 1.000000000000000e+00, 1.000000000000000e+00, 1.036366915038913e+00, -2.016924107725964e+00 +76 , 1.000000000000000e+00, 1.000000000000000e+00, 1.334871147341559e-01, -1.634604715418913e+00 +77 , 1.000000000000000e+00, 1.000000000000000e+00, 7.123254204690519e-01, -2.150275095672414e+00 +78 , 1.000000000000000e+00, 1.000000000000000e+00, 1.759107096658776e+00, -2.342128876529649e+00 +79 , 1.000000000000000e+00, 1.000000000000000e+00, 5.445726305421505e-02, -1.698710028312236e+00 +80 , 1.000000000000000e+00, 1.000000000000000e+00, 7.553415620459540e-01, 1.000000000000000e-04 +81 , 1.000000000000000e+00, 1.000000000000000e+00, -2.764313999225854e-02, -1.519240762581481e+00 +82 , 1.000000000000000e+00, 1.000000000000000e+00, -4.406804475082324e-01, -2.024875026617072e+00 +83 , 1.000000000000000e+00, 1.000000000000000e+00, -9.929257149617352e-01, -2.241942575124601e+00 +84 , 1.000000000000000e+00, 1.000000000000000e+00, -1.466600027097579e+00, -2.984909012607663e+00 +85 , 1.000000000000000e+00, 1.000000000000000e+00, -5.990304867158840e-01, -2.388164897459385e+00 +86 , 1.000000000000000e+00, 1.000000000000000e+00, 3.040420794796370e-01, -1.894050465195215e+00 +87 , 1.000000000000000e+00, 1.000000000000000e+00, -5.909515296974093e-01, -2.454144932345226e+00 +88 , 1.000000000000000e+00, 1.000000000000000e+00, -1.091152792865723e+00, -2.563576277205860e+00 +89 , 1.000000000000000e+00, 1.000000000000000e+00, -6.755252548115369e-01, -2.593071076035451e+00 +90 , 1.000000000000000e+00, 1.000000000000000e+00, 6.506490705074306e-01, -2.742653045444400e+00 +91 , 1.000000000000000e+00, 1.000000000000000e+00, 1.321034297602704e+00, -2.220389516459539e+00 +92 , 1.000000000000000e+00, 1.000000000000000e+00, 3.854877052279315e-02, -2.765058645463596e+00 +93 , 1.000000000000000e+00, 1.000000000000000e+00, -1.153450083656313e-01, -1.522894852256558e+00 +94 , 1.000000000000000e+00, 1.000000000000000e+00, -1.185090801197946e-01, -2.756212326574877e+00 +95 , 1.000000000000000e+00, 1.000000000000000e+00, 3.123253615639401e-01, -3.575465250587423e+00 +96 , 1.000000000000000e+00, 1.000000000000000e+00, 2.245979218959215e-02, -2.016739417798566e+00 +97 , 1.000000000000000e+00, 1.000000000000000e+00, -1.260091086602861e-01, -3.076103843283174e+00 +98 , 1.000000000000000e+00, 1.000000000000000e+00, -3.656366231240911e-01, -3.116616975503573e+00 +99 , 1.000000000000000e+00, 1.000000000000000e+00, -4.323166403743459e-01, -1.373441707188801e+00 diff --git a/tests/pytest/test_descriptor_identifier/model_files/train_log_regressor.dat b/tests/pytest/test_descriptor_identifier/model_files/train_log_regressor.dat index f33da46e688e73a3c081fb85dd020bd315e3a8b7..7a569fde93567aef38d22bb0b899e2cd50e137df 100644 --- a/tests/pytest/test_descriptor_identifier/model_files/train_log_regressor.dat +++ b/tests/pytest/test_descriptor_identifier/model_files/train_log_regressor.dat @@ -1,9 +1,9 @@ -# exp(c0) * ((B + A))^a0 * ((|D - B|))^a1 +# ((B + A))^a0 * ((|D - B|))^a1 # Property Label: $Prop$; Unit of the Property: Unitless # RMSE: 3.17364877036896e-10; Max AE: 3.06818037643097e-09 # Coefficients -# Task a0 a1 c0 -# all , 1.199999999999988e+00, -1.950000000000029e+00, 2.194569927587456e-13, +# Task a0 a1 +# all , 1.199999999999988e+00, -1.950000000000029e+00, # Feature Rung, Units, and Expressions # 0; 1; Unitless; 1|0|add; (B + A); $\left(B + A\right)$; (B + A); B,A # 1; 1; Unitless; 3|1|abd; (|D - B|); $\left(\left|D - B\right|\right)$; abs(D - B); D,B diff --git a/tests/pytest/test_descriptor_identifier/test_class_model_from_file.py b/tests/pytest/test_descriptor_identifier/test_class_model_from_file.py index 1de2f3ec5a1f65297555f16be8e23e3bf128c68a..c3c398d64e6a533da1c38e6b0b17632d90c01c2f 100644 --- a/tests/pytest/test_descriptor_identifier/test_class_model_from_file.py +++ b/tests/pytest/test_descriptor_identifier/test_class_model_from_file.py @@ -25,19 +25,38 @@ parent = Path(__file__).parent def test_class_model_from_file(): + try: + model = load_model( + str(parent / "model_files/train_classifier_fail_overlap.dat"), + str(parent / "model_files/test_classifier.dat"), + ) + raise ValueError("Model created that should fail") + except RuntimeError: + pass + + try: + model = load_model( + str(parent / "model_files/train_classifier.dat"), + str(parent / "model_files/test_classifier_fail_overlap.dat"), + ) + raise ValueError("Model created that should fail") + except RuntimeError: + pass + model = load_model( str(parent / "model_files/train_classifier.dat"), str(parent / "model_files/test_classifier.dat"), ) - mat_fxn_fn = "model_classifier.m" + mat_fxn_fn = "test_matlab_fxn/model_classifier" mat_fxn_fn_real = str(parent / "matlab_functions" / "model_classifier.m") model.write_matlab_fxn(mat_fxn_fn) actual_lines = open(mat_fxn_fn_real).readlines() - test_lines = open(mat_fxn_fn).readlines() + test_lines = open(mat_fxn_fn + ".m").readlines() - Path(mat_fxn_fn).unlink() + Path(mat_fxn_fn + ".m").unlink() + Path("test_matlab_fxn").rmdir() for tl, al in zip(test_lines, actual_lines): assert tl == al diff --git a/tests/pytest/test_descriptor_identifier/test_class_model_train_from_file.py b/tests/pytest/test_descriptor_identifier/test_class_model_train_from_file.py index 69789dae0d77152dff61581ff316fd9e3524d66c..575b841cb2320d95ecfd0a471f84836e08e5702e 100644 --- a/tests/pytest/test_descriptor_identifier/test_class_model_train_from_file.py +++ b/tests/pytest/test_descriptor_identifier/test_class_model_train_from_file.py @@ -25,6 +25,14 @@ parent = Path(__file__).parent def test_class_model_train_from_file(): + try: + model = load_model( + str(parent / "model_files/train_classifier_fail_overlap.dat"), + ) + raise ValueError("Model created that should fail") + except RuntimeError: + pass + model = load_model(str(parent / "model_files/train_classifier.dat")) assert np.all(np.abs(model.fit - model.prop_train) < 1e-7) diff --git a/tests/pytest/test_descriptor_identifier/test_classifier.py b/tests/pytest/test_descriptor_identifier/test_classifier.py index 269603494f597107bdf589c0529d39729b288172..c784219eee34f863a6ba9bd309c2b895c1b32bf2 100644 --- a/tests/pytest/test_descriptor_identifier/test_classifier.py +++ b/tests/pytest/test_descriptor_identifier/test_classifier.py @@ -121,11 +121,17 @@ def test_sisso_classifier(): shutil.rmtree("models/") shutil.rmtree("feature_space/") - assert sisso.models[0][0].n_convex_overlap_train == 4 - assert sisso.models[1][0].n_convex_overlap_train == 0 + # assert sisso.models[0][0].n_convex_overlap_train == 4 + # assert sisso.models[1][0].n_convex_overlap_train == 0 assert sisso.models[0][0].n_convex_overlap_test == 0 - assert sisso.models[1][0].n_convex_overlap_test == 0 + # assert sisso.models[1][0].n_convex_overlap_test == 0 + + assert np.all(sisso.prop_train != inputs.prop_train) + assert np.all(sisso.prop_test != inputs.prop_test) + + assert np.all(sisso.task_sizes_train == inputs.task_sizes_train) + assert np.all(sisso.task_sizes_test == inputs.task_sizes_test) if __name__ == "__main__": diff --git a/tests/pytest/test_descriptor_identifier/test_log_reg_model_from_file.py b/tests/pytest/test_descriptor_identifier/test_log_reg_model_from_file.py index a7ad48dd749cedcf2f44c2c22fac76c301e7dbca..e125c8e45f09876b0bc65ac3adadda677298178d 100644 --- a/tests/pytest/test_descriptor_identifier/test_log_reg_model_from_file.py +++ b/tests/pytest/test_descriptor_identifier/test_log_reg_model_from_file.py @@ -55,7 +55,7 @@ def test_log_reg_model_from_file(): assert model.feats[1].postfix_expr == "3|1|abd" actual_coefs = [ - [1.20, -1.95, 2.194569927587456e-13], + [1.20, -1.95], ] assert np.all( @@ -89,9 +89,10 @@ def test_log_reg_model_from_file(): assert model.percentile_95_ae < 1e-7 assert model.percentile_95_test_ae < 1e-7 + print(model.latex_str) assert ( model.latex_str - == "$\\exp\\left(c_0\\right)\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$" + == "$\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$" ) diff --git a/tests/pytest/test_descriptor_identifier/test_log_reg_train_model_from_file.py b/tests/pytest/test_descriptor_identifier/test_log_reg_train_model_from_file.py index fdb945401e9182dc1fb3a66a855ded363eef5ba6..d6ed7824cd33aa0906ade355b13878bdd81d3f7a 100644 --- a/tests/pytest/test_descriptor_identifier/test_log_reg_train_model_from_file.py +++ b/tests/pytest/test_descriptor_identifier/test_log_reg_train_model_from_file.py @@ -40,7 +40,7 @@ def test_log_reg_model_from_file(): assert model.feats[1].postfix_expr == "3|1|abd" actual_coefs = [ - [1.20, -1.95, 2.194569927587456e-13], + [1.20, -1.95], ] assert np.all( @@ -60,7 +60,7 @@ def test_log_reg_model_from_file(): assert model.percentile_95_ae < 1e-7 assert ( model.latex_str - == "$\\exp\\left(c_0\\right)\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$" + == "$\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$" ) diff --git a/tests/pytest/test_descriptor_identifier/test_log_regressor.py b/tests/pytest/test_descriptor_identifier/test_log_regressor.py index bb56c606ba77f3f97bf53b180fefe9b4eb927735..292868b6c9fa97f25f4a6d6fa06c68efcde97f54 100644 --- a/tests/pytest/test_descriptor_identifier/test_log_regressor.py +++ b/tests/pytest/test_descriptor_identifier/test_log_regressor.py @@ -83,6 +83,12 @@ def test_sisso_log_regressor(): assert sisso.models[1][0].rmse < 1e-7 assert sisso.models[1][0].test_rmse < 1e-7 + assert np.all(np.abs(sisso.prop_train - np.log(inputs.prop_train)) < 1e-10) + assert np.all(np.abs(sisso.prop_test - np.log(inputs.prop_test)) < 1e-10) + + assert np.all(sisso.task_sizes_train == inputs.task_sizes_train) + assert np.all(sisso.task_sizes_test == inputs.task_sizes_test) + if __name__ == "__main__": test_sisso_log_regressor() diff --git a/tests/pytest/test_descriptor_identifier/test_regressor.py b/tests/pytest/test_descriptor_identifier/test_regressor.py index 9bdbf81eee94ced8514a37f1911fb4fbd8e41b38..1b7be5164e212b87ffa734894917e34f8cad2c33 100644 --- a/tests/pytest/test_descriptor_identifier/test_regressor.py +++ b/tests/pytest/test_descriptor_identifier/test_regressor.py @@ -93,6 +93,12 @@ def test_sisso_regressor(): assert np.all(inputs.task_sizes_train == sisso.models[1][0].task_sizes_train) assert np.all(inputs.task_sizes_test == sisso.models[1][0].task_sizes_test) + assert np.all(np.abs(sisso.prop_train - inputs.prop_train) < 1e-10) + assert np.all(np.abs(sisso.prop_test - inputs.prop_test) < 1e-10) + + assert np.all(sisso.task_sizes_train == inputs.task_sizes_train) + assert np.all(sisso.task_sizes_test == inputs.task_sizes_test) + if __name__ == "__main__": test_sisso_regressor()