Commit 82910ba8 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Update Descriptor_Identifier tests

Add exception tests
parent c43722be
......@@ -86,10 +86,6 @@ double Model::eval(std::map<std::string, double> x_in_dct) const
{
throw std::logic_error("The value of " + in_expr + " is not in x_in_dct.");
}
else if(x_in_dct.count(in_expr) > 1)
{
throw std::logic_error("Multiple values of " + in_expr + " defined in x_in_dct.");
}
x_in.push_back(x_in_dct[in_expr]);
}
......@@ -132,10 +128,6 @@ std::vector<double> Model::eval(std::map<std::string, std::vector<double>> x_in_
{
throw std::logic_error("The value of " + in_expr + " is not in x_in_dct.");
}
else if(x_in_dct.count(in_expr) > 1)
{
throw std::logic_error("Multiple values of " + in_expr + " defined in x_in_dct.");
}
x_in.push_back(x_in_dct[in_expr]);
}
......@@ -276,6 +268,8 @@ void Model::write_matlab_fxn(std::string fxn_filename)
}
boost::filesystem::path p(fxn_filename.c_str());
std::string fxn_name = p.filename().string();
fxn_name = fxn_name.substr(0, fxn_name.size() - 2);
boost::filesystem::path parent = p.remove_filename();
if(parent.string().size() > 0)
{
......@@ -296,7 +290,7 @@ void Model::write_matlab_fxn(std::string fxn_filename)
std::transform(leaves.begin(), leaves.end(), leaves.begin(), [](std::string s){return str_utils::matlabify(s);});
// Write the header of the function
out_file_stream << "function P = " << fxn_filename.substr(0, fxn_filename.size() - 2) << "(X)\n";
out_file_stream << "function P = " << fxn_name << "(X)\n";
out_file_stream << "% Returns the value of " << _prop_label << " = " << toString() << "\n%\n";
out_file_stream << "% X = [\n";
for(auto & leaf : leaves)
......@@ -373,6 +367,8 @@ void Model::populate_model(const std::string train_filename, const std::string t
std::string test_error_line;
std::getline(train_file_stream, prop_desc_line);
int n_line = 5;
// Legacy Code so previous model files can be read in
if(!is_error_line(prop_desc_line))
{
split_line = str_utils::split_string_trim(prop_desc_line);
......
......@@ -59,7 +59,9 @@ ModelClassifier::ModelClassifier(
_test_n_svm_misclassified = std::accumulate(test_misclassified.begin(), test_misclassified.end(), 0);
}
ModelClassifier::ModelClassifier(const std::string train_file)
ModelClassifier::ModelClassifier(const std::string train_file) :
_train_n_convex_overlap(0),
_test_n_convex_overlap(0)
{
populate_model(train_file);
_n_class = _loss->n_class();
......@@ -79,7 +81,9 @@ ModelClassifier::ModelClassifier(const std::string train_file)
);
}
}
ModelClassifier::ModelClassifier(const std::string train_file, std::string test_file)
ModelClassifier::ModelClassifier(const std::string train_file, std::string test_file) :
_train_n_convex_overlap(0),
_test_n_convex_overlap(0)
{
populate_model(train_file, test_file);
_n_class = _loss->n_class();
......
......@@ -43,10 +43,9 @@ class ModelClassifier : public Model
int _train_n_svm_misclassified; //!< The number of points misclassified by SVM (training set)
int _test_n_svm_misclassified; //!< The number of points misclassified by SVM (test set)
protected:
using Model::eval;
public:
using Model::eval;
/**
* @brief Construct a ModelClassifier using a loss function and a set of features
......
......@@ -126,7 +126,7 @@ std::string ModelLogRegressor::toLatexString() const
std::stringstream model_rep;
if(_fix_intercept)
{
model_rep << "$\\left(" << _feats[0]->get_latex_expr() << "\\right)^{a_0}" << std::endl;
model_rep << "$\\left(" << _feats[0]->get_latex_expr() << "\\right)^{a_0}";
for(int ff = 1; ff < _feats.size(); ++ff)
{
model_rep << "\\left(" << _feats[ff]->get_latex_expr() << "\\right)^{a_" << ff << "}";
......
......@@ -128,6 +128,11 @@ public:
*/
inline int n_models_store() const {return _n_models_store;}
/**
* @brief If true the bias term is fixed at 0
*/
inline bool fix_intercept() const {return _fix_intercept;}
// Python interface functions
#ifdef PY_BINDINGS
......
......@@ -143,4 +143,35 @@ namespace
boost::filesystem::remove("train_class_mods.dat");
boost::filesystem::remove("test_class_mods.dat");
}
TEST_F(ModelClassifierTests, EvalTest)
{
ModelClassifier model(
"Property",
Unit("m"),
_loss,
_features,
_leave_out_inds,
_sample_ids_train,
_sample_ids_test,
task_names
);
model.set_task_eval(0);
std::vector<double> pt = {2.0, 2.0};
EXPECT_THROW(model.eval(pt), std::logic_error);
std::map<std::string, double> pt_dct;
pt_dct["A"] = 1.0;
pt_dct["B"] = 1.0;
EXPECT_THROW(model.eval(pt_dct), std::logic_error);
std::vector<std::vector<double>> pts = {{1.0}, {1.0}};
EXPECT_THROW(model.eval(pts), std::logic_error);
std::map<std::string, std::vector<double>> pts_dct;
pts_dct["A"] = {1.0};
pts_dct["B"] = {1.0};
EXPECT_THROW(model.eval(pts_dct), std::logic_error);
}
}
......@@ -127,6 +127,8 @@ namespace
EXPECT_LT(std::abs(model.coefs()[0][1] + 2.1), 1e-10);
EXPECT_LT(std::abs(model.coefs()[0][2] - std::log(0.001)), 1e-10);
EXPECT_STREQ(model.toLatexString().c_str(), "$\\exp\\left(c_0\\right)\\left(A\\right)^{a_0}\\left(B\\right)^{a_1}$");
model.to_file("train_false_log_reg.dat", true);
model.to_file("test_false_log_reg.dat", false);
}
......@@ -248,6 +250,8 @@ namespace
model.to_file("train_true_log_reg.dat", true);
model.to_file("test_true_log_reg.dat", false);
EXPECT_STREQ(model.toLatexString().c_str(), "$\\left(A\\right)^{a_0}\\left(B\\right)^{a_1}$");
}
TEST_F(ModelLogRegssorTests, FixInterceptTrueFileTest)
......@@ -340,5 +344,23 @@ namespace
pts_dct["B"] = {1.0};
val = model.eval(pts_dct)[0];
EXPECT_LT(val - 0.00025, 1e-10);
pt.push_back(1.0);
EXPECT_THROW(model.eval(pt), std::logic_error);
pts.push_back({1.0});
EXPECT_THROW(model.eval(pts), std::logic_error);
pts.pop_back();
pts.back().push_back(1.0);
EXPECT_THROW(model.eval(pts), std::logic_error);
pts_dct["A"] = {1.0, 1.0};
EXPECT_THROW(model.eval(pts_dct), std::logic_error);
pt_dct.erase("A");
pts_dct.erase("A");
EXPECT_THROW(model.eval(pt_dct), std::logic_error);
EXPECT_THROW(model.eval(pts_dct), std::logic_error);
}
}
......@@ -133,6 +133,7 @@ namespace
EXPECT_LT(std::abs(model.coefs()[1][1] + 0.4), 1e-10);
EXPECT_LT(std::abs(model.coefs()[1][2] + 6.5), 1e-10);
EXPECT_STREQ(model.toLatexString().c_str(), "$c_0 + a_0A + a_1B$");
model.to_file("train_false.dat", true);
model.to_file("test_false.dat", false);
}
......@@ -259,6 +260,8 @@ namespace
EXPECT_LT(std::abs(model.coefs()[1][0] - 1.25), 1e-10);
EXPECT_LT(std::abs(model.coefs()[1][1] + 0.4), 1e-10);
EXPECT_STREQ(model.toLatexString().c_str(), "$a_0A + a_1B$");
model.to_file("train_true.dat", true);
model.to_file("test_true.dat", false);
}
......@@ -359,5 +362,23 @@ namespace
EXPECT_LT(model.eval(pt_dct) + 5.65, 1e-10);
EXPECT_LT(model.eval(pts)[0] + 5.65, 1e-10);
EXPECT_LT(model.eval(pts_dct)[0] + 5.65, 1e-10);
pt.push_back(1.0);
EXPECT_THROW(model.eval(pt), std::logic_error);
pts.push_back({1.0});
EXPECT_THROW(model.eval(pts), std::logic_error);
pts.pop_back();
pts.back().push_back(1.0);
EXPECT_THROW(model.eval(pts), std::logic_error);
pts_dct["A"] = {1.0, 1.0};
EXPECT_THROW(model.eval(pts_dct), std::logic_error);
pt_dct.erase("A");
pts_dct.erase("A");
EXPECT_THROW(model.eval(pt_dct), std::logic_error);
EXPECT_THROW(model.eval(pts_dct), std::logic_error);
}
}
......@@ -201,4 +201,38 @@ namespace
boost::filesystem::remove_all("feature_space/");
boost::filesystem::remove_all("models/");
}
TEST_F(SISSOClassifierTests, FixInterceptTrueTest)
{
std::shared_ptr<FeatureSpace> feat_space = std::make_shared<FeatureSpace>(inputs);
inputs.set_fix_intercept(true);
SISSOClassifier sisso(inputs, feat_space);
EXPECT_FALSE(sisso.fix_intercept());
std::vector<double> prop_comp(80, 0.0);
std::transform(inputs.prop_train().begin(), inputs.prop_train().end(), sisso.prop_train().begin(), prop_comp.begin(), [](double p1, double p2){return std::abs(p1 - p2);});
EXPECT_FALSE(std::any_of(prop_comp.begin(), prop_comp.end(), [](double p){return p > 1e-10;}));
std::transform(inputs.prop_test().begin(), inputs.prop_test().begin() + 10, sisso.prop_test().begin(), prop_comp.begin(), [](double p1, double p2){return std::abs(p1 - p2);});
EXPECT_FALSE(std::any_of(prop_comp.begin(), prop_comp.begin() + 10, [](double p){return p > 1e-10;}));
EXPECT_EQ(sisso.n_samp(), 80);
EXPECT_EQ(sisso.n_dim(), 2);
EXPECT_EQ(sisso.n_residual(), 2);
EXPECT_EQ(sisso.n_models_store(), 3);
sisso.fit();
EXPECT_EQ(sisso.models().size(), 2);
EXPECT_EQ(sisso.models()[0].size(), 3);
EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0);
EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_test(), 0);
EXPECT_EQ(sisso.models().back()[0].n_svm_misclassified_train(), 0);
EXPECT_EQ(sisso.models().back()[0].n_svm_misclassified_test(), 0);
boost::filesystem::remove_all("feature_space/");
boost::filesystem::remove_all("models/");
}
}
function P = model_log_regressor(X)
% Returns the value of Prop = exp(c0) * ((B + A))^a0 * ((|D - B|))^a1
% Returns the value of Prop = ((B + A))^a0 * ((|D - B|))^a1
%
% X = [
% B,
......@@ -17,7 +17,7 @@ D = reshape(X(:, 3), 1, []);
f0 = (B + A);
f1 = abs(D - B);
c0 = 2.1945699276e-13;
c0 = 0.0;
a0 = 1.2000000000e+00;
a1 = -1.9500000000e+00;
......
# [(feat_9 - feat_8), (feat_1 * feat_0)]
# Property Label: $Class$; Unit of the Property: Unitless
# # Samples in Convex Hull Overlap Region: 5;# Samples SVM Misclassified: 0
# Decision Boundaries
# Task w0 w1 b
# all_0, 1.326205649731981e+00, -1.744239999671528e+00, 9.075950727790907e-01,
# Feature Rung, Units, and Expressions
# 0; 1; Unitless; 9|8|sub; (feat_9 - feat_8); $\left(feat_{9} - feat_{8}\right)$; (feat_9 - feat_8); feat_9,feat_8
# 1; 1; Unitless; 1|0|mult; (feat_1 * feat_0); $\left(feat_{1} feat_{0}\right)$; (feat_1 .* feat_0); feat_1,feat_0
# Number of Samples Per Task
# Task, n_mats_test
# all, 20
# Test Indexes: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ]
# Sample ID , Property Value , Property Value (EST) , Feature 0 Value , Feature 1 Value
0 , 0.000000000000000e+00, 0.000000000000000e+00, -2.031012155743963e-01, 2.746937325370922e+00
1 , 0.000000000000000e+00, 0.000000000000000e+00, 3.178950007570245e-01, 3.060176523352919e+00
2 , 0.000000000000000e+00, 0.000000000000000e+00, 1.350899479575136e+00, 1.904914737747669e+00
3 , 0.000000000000000e+00, 0.000000000000000e+00, 3.112816979685040e-01, 2.597514970348419e+00
4 , 0.000000000000000e+00, 0.000000000000000e+00, 3.256274649800963e-01, 3.823832277859604e+00
5 , 0.000000000000000e+00, 0.000000000000000e+00, 7.291401227120657e-01, 2.789443909864211e+00
6 , 0.000000000000000e+00, 0.000000000000000e+00, 3.051059498409199e-01, 2.087853428517832e+00
7 , 0.000000000000000e+00, 0.000000000000000e+00, 5.345306546910435e-01, 2.507012794375703e+00
8 , 0.000000000000000e+00, 0.000000000000000e+00, -5.273941950401386e-01, 1.812203393718137e+00
9 , 0.000000000000000e+00, 0.000000000000000e+00, -1.780367164555883e-01, 3.143604947474592e+00
10 , 1.000000000000000e+00, 1.000000000000000e+00, 1.274772829175870e+00, -2.054835335229399e+00
11 , 1.000000000000000e+00, 1.000000000000000e+00, -1.099107097822589e+00, -2.210701955514223e+00
12 , 1.000000000000000e+00, 1.000000000000000e+00, -2.522737334308300e-02, -2.127724030671242e+00
13 , 1.000000000000000e+00, 1.000000000000000e+00, -8.048228984834345e-01, -3.158579181125339e+00
14 , 1.000000000000000e+00, 1.000000000000000e+00, -1.875592975314526e-01, -1.974183109498213e+00
15 , 1.000000000000000e+00, 1.000000000000000e+00, 6.149517560499549e-01, -1.721250821422664e+00
16 , 1.000000000000000e+00, 1.000000000000000e+00, 1.679421386195452e+00, -2.639919246265093e+00
17 , 1.000000000000000e+00, 1.000000000000000e+00, -3.729001722113563e-01, -2.014587145399039e+00
18 , 1.000000000000000e+00, 1.000000000000000e+00, -1.241140800579893e+00, -1.410625471724265e+00
19 , 1.000000000000000e+00, 1.000000000000000e+00, 3.358150821235752e-01, -1.769187643167631e+00
# exp(c0) * ((B + A))^a0 * ((|D - B|))^a1
# ((B + A))^a0 * ((|D - B|))^a1
# Property Label: $Prop$; Unit of the Property: Unitless
# RMSE: 1.61410365875894e-15; Max AE: 3.10862446895044e-15
# Coefficients
# Task a0 a1 c0
# all , 1.199999999999988e+00, -1.950000000000029e+00, 2.194569927587456e-13,
# Task a0 a1
# all , 1.199999999999988e+00, -1.950000000000029e+00,
# Feature Rung, Units, and Expressions
# 0; 1; Unitless; 1|0|add; (B + A); $\left(B + A\right)$; (B + A); B,A
# 1; 1; Unitless; 3|1|abd; (|D - B|); $\left(\left|D - B\right|\right)$; abs(D - B); D,B
......
# [(feat_9 - feat_8), (feat_1 * feat_0)]
# Property Label: $Class$; Unit of the Property: Unitless
# # Samples in Convex Hull Overlap Region: 5;# Samples SVM Misclassified: 0
# Decision Boundaries
# Task w0 w1 b
# all_0, 1.326205649731981e+00, -1.744239999671528e+00, 9.075950727790907e-01,
# Feature Rung, Units, and Expressions
# 0; 1; Unitless; 9|8|sub; (feat_9 - feat_8); $\left(feat_{9} - feat_{8}\right)$; (feat_9 - feat_8); feat_9,feat_8
# 1; 1; Unitless; 1|0|mult; (feat_1 * feat_0); $\left(feat_{1} feat_{0}\right)$; (feat_1 .* feat_0); feat_1,feat_0
# Number of Samples Per Task
# Task, n_mats_train
# all , 80
# Sample ID , Property Value , Property Value (EST) , Feature 0 Value , Feature 1 Value
20 , 0.000000000000000e+00, 0.000000000000000e+00, -1.438535620957356e+00, -1.000000000000000e-04
21 , 0.000000000000000e+00, 0.000000000000000e+00, 1.190224286778585e-01, 1.551266013053755e+00
22 , 0.000000000000000e+00, 0.000000000000000e+00, -1.425077659929501e-01, 2.130683687424777e+00
23 , 0.000000000000000e+00, 0.000000000000000e+00, -4.540168315339039e-01, 2.634756018203185e+00
24 , 0.000000000000000e+00, 0.000000000000000e+00, 6.192845577547714e-01, 1.928013807462464e+00
25 , 0.000000000000000e+00, 0.000000000000000e+00, -8.781439552075476e-01, 2.167912710491058e+00
26 , 0.000000000000000e+00, 0.000000000000000e+00, -3.535569323953591e-01, 1.528444956153448e+00
27 , 0.000000000000000e+00, 0.000000000000000e+00, 1.339601116269036e-01, 1.802459125102224e+00
28 , 0.000000000000000e+00, 0.000000000000000e+00, -1.136023513610327e+00, 3.215479730179889e+00
29 , 0.000000000000000e+00, 0.000000000000000e+00, 6.038078099519975e-01, 2.573190329369634e+00
30 , 0.000000000000000e+00, 0.000000000000000e+00, -1.215056999365121e-01, 1.595270557148176e+00
31 , 0.000000000000000e+00, 0.000000000000000e+00, -5.518151942675462e-01, 1.730168908796035e+00
32 , 0.000000000000000e+00, 0.000000000000000e+00, 1.650765229640842e+00, 3.203700967878904e+00
33 , 0.000000000000000e+00, 0.000000000000000e+00, -1.184077335611730e+00, 1.517408938916172e+00
34 , 0.000000000000000e+00, 0.000000000000000e+00, -2.221541956972297e-01, 1.751520526988180e+00
35 , 0.000000000000000e+00, 0.000000000000000e+00, -4.681000124924655e-01, 1.824502458636519e+00
36 , 0.000000000000000e+00, 0.000000000000000e+00, -1.050081904577687e-01, 1.724802253834064e+00
37 , 0.000000000000000e+00, 0.000000000000000e+00, 1.339981411155463e+00, 2.558207468331875e+00
38 , 0.000000000000000e+00, 0.000000000000000e+00, -1.684823671566578e-01, 1.506294601099636e+00
39 , 0.000000000000000e+00, 0.000000000000000e+00, 7.714455467324950e-01, 2.242769710603608e+00
40 , 0.000000000000000e+00, 0.000000000000000e+00, -1.622660863089168e+00, -1.000000000000000e-04
41 , 0.000000000000000e+00, 0.000000000000000e+00, -3.851141611133064e-01, 1.621724598445333e+00
42 , 0.000000000000000e+00, 0.000000000000000e+00, -7.796131832604434e-01, 3.412199890833602e+00
43 , 0.000000000000000e+00, 0.000000000000000e+00, 6.330392653717503e-01, 2.644525290379403e+00
44 , 0.000000000000000e+00, 0.000000000000000e+00, 4.472597809306964e-01, 1.639977210905994e+00
45 , 0.000000000000000e+00, 0.000000000000000e+00, 5.619997969970609e-01, 2.117832540122095e+00
46 , 0.000000000000000e+00, 0.000000000000000e+00, 2.693335908708820e-01, 3.719200588905764e+00
47 , 0.000000000000000e+00, 0.000000000000000e+00, -6.945169947212362e-01, 2.658310913357233e+00
48 , 0.000000000000000e+00, 0.000000000000000e+00, -2.608343436805389e-01, 2.389278127799646e+00
49 , 0.000000000000000e+00, 0.000000000000000e+00, -5.883177866461617e-01, 1.194385279781109e+00
50 , 0.000000000000000e+00, 0.000000000000000e+00, 3.016305034685407e-01, 2.163287243369974e+00
51 , 0.000000000000000e+00, 0.000000000000000e+00, -8.429615971293545e-01, 3.143453483796918e+00
52 , 0.000000000000000e+00, 0.000000000000000e+00, -2.305301655628482e-01, 2.373605928069240e+00
53 , 0.000000000000000e+00, 0.000000000000000e+00, 8.169785601229205e-01, 3.393041023148420e+00
54 , 0.000000000000000e+00, 0.000000000000000e+00, 5.880968210966282e-01, 1.540775049989281e+00
55 , 0.000000000000000e+00, 0.000000000000000e+00, -8.439557195782834e-01, 2.354515308140759e+00
56 , 0.000000000000000e+00, 0.000000000000000e+00, 1.145691901781707e-01, 3.057598248128036e+00
57 , 0.000000000000000e+00, 0.000000000000000e+00, 5.052378789612302e-01, 3.681321981867383e+00
58 , 0.000000000000000e+00, 0.000000000000000e+00, 6.830515610497974e-01, 2.677195784075541e+00
59 , 0.000000000000000e+00, 0.000000000000000e+00, -3.962323210078385e-01, 2.494759927195949e+00
60 , 1.000000000000000e+00, 1.000000000000000e+00, 6.978964070103477e-02, 1.000000000000000e-04
61 , 1.000000000000000e+00, 1.000000000000000e+00, -1.017083127934503e+00, -1.718221667867104e+00
62 , 1.000000000000000e+00, 1.000000000000000e+00, -5.038361140934988e-02, -3.023687952494995e+00
63 , 1.000000000000000e+00, 1.000000000000000e+00, 2.981066824725631e-02, -2.580950415579647e+00
64 , 1.000000000000000e+00, 1.000000000000000e+00, 1.173640969341423e+00, -2.015913518051938e+00
65 , 1.000000000000000e+00, 1.000000000000000e+00, 8.711405011252915e-02, -3.331488038371359e+00
66 , 1.000000000000000e+00, 1.000000000000000e+00, 1.309781594456224e+00, -2.340258337136148e+00
67 , 1.000000000000000e+00, 1.000000000000000e+00, -2.024028100438937e-01, -1.817820634181115e+00
68 , 1.000000000000000e+00, 1.000000000000000e+00, -2.684686877159819e-01, -1.754047733957138e+00
69 , 1.000000000000000e+00, 1.000000000000000e+00, 1.446320111150274e-01, -2.385204762866371e+00
70 , 1.000000000000000e+00, 1.000000000000000e+00, -2.832821671606189e-01, -2.001289065001360e+00
71 , 1.000000000000000e+00, 1.000000000000000e+00, -3.128846468236810e-01, -1.884355389893358e+00
72 , 1.000000000000000e+00, 1.000000000000000e+00, 1.383377419667691e-01, -2.044929395284636e+00
73 , 1.000000000000000e+00, 1.000000000000000e+00, -8.811671096262539e-01, -1.442201355797836e+00
74 , 1.000000000000000e+00, 1.000000000000000e+00, 6.544208451153577e-02, -1.908068625698732e+00
75 , 1.000000000000000e+00, 1.000000000000000e+00, 1.036366915038913e+00, -2.016924107725964e+00
76 , 1.000000000000000e+00, 1.000000000000000e+00, 1.334871147341559e-01, -1.634604715418913e+00
77 , 1.000000000000000e+00, 1.000000000000000e+00, 7.123254204690519e-01, -2.150275095672414e+00
78 , 1.000000000000000e+00, 1.000000000000000e+00, 1.759107096658776e+00, -2.342128876529649e+00
79 , 1.000000000000000e+00, 1.000000000000000e+00, 5.445726305421505e-02, -1.698710028312236e+00
80 , 1.000000000000000e+00, 1.000000000000000e+00, 7.553415620459540e-01, 1.000000000000000e-04
81 , 1.000000000000000e+00, 1.000000000000000e+00, -2.764313999225854e-02, -1.519240762581481e+00
82 , 1.000000000000000e+00, 1.000000000000000e+00, -4.406804475082324e-01, -2.024875026617072e+00
83 , 1.000000000000000e+00, 1.000000000000000e+00, -9.929257149617352e-01, -2.241942575124601e+00
84 , 1.000000000000000e+00, 1.000000000000000e+00, -1.466600027097579e+00, -2.984909012607663e+00
85 , 1.000000000000000e+00, 1.000000000000000e+00, -5.990304867158840e-01, -2.388164897459385e+00
86 , 1.000000000000000e+00, 1.000000000000000e+00, 3.040420794796370e-01, -1.894050465195215e+00
87 , 1.000000000000000e+00, 1.000000000000000e+00, -5.909515296974093e-01, -2.454144932345226e+00
88 , 1.000000000000000e+00, 1.000000000000000e+00, -1.091152792865723e+00, -2.563576277205860e+00
89 , 1.000000000000000e+00, 1.000000000000000e+00, -6.755252548115369e-01, -2.593071076035451e+00
90 , 1.000000000000000e+00, 1.000000000000000e+00, 6.506490705074306e-01, -2.742653045444400e+00
91 , 1.000000000000000e+00, 1.000000000000000e+00, 1.321034297602704e+00, -2.220389516459539e+00
92 , 1.000000000000000e+00, 1.000000000000000e+00, 3.854877052279315e-02, -2.765058645463596e+00
93 , 1.000000000000000e+00, 1.000000000000000e+00, -1.153450083656313e-01, -1.522894852256558e+00
94 , 1.000000000000000e+00, 1.000000000000000e+00, -1.185090801197946e-01, -2.756212326574877e+00
95 , 1.000000000000000e+00, 1.000000000000000e+00, 3.123253615639401e-01, -3.575465250587423e+00
96 , 1.000000000000000e+00, 1.000000000000000e+00, 2.245979218959215e-02, -2.016739417798566e+00
97 , 1.000000000000000e+00, 1.000000000000000e+00, -1.260091086602861e-01, -3.076103843283174e+00
98 , 1.000000000000000e+00, 1.000000000000000e+00, -3.656366231240911e-01, -3.116616975503573e+00
99 , 1.000000000000000e+00, 1.000000000000000e+00, -4.323166403743459e-01, -1.373441707188801e+00
# exp(c0) * ((B + A))^a0 * ((|D - B|))^a1
# ((B + A))^a0 * ((|D - B|))^a1
# Property Label: $Prop$; Unit of the Property: Unitless
# RMSE: 3.17364877036896e-10; Max AE: 3.06818037643097e-09
# Coefficients
# Task a0 a1 c0
# all , 1.199999999999988e+00, -1.950000000000029e+00, 2.194569927587456e-13,
# Task a0 a1
# all , 1.199999999999988e+00, -1.950000000000029e+00,
# Feature Rung, Units, and Expressions
# 0; 1; Unitless; 1|0|add; (B + A); $\left(B + A\right)$; (B + A); B,A
# 1; 1; Unitless; 3|1|abd; (|D - B|); $\left(\left|D - B\right|\right)$; abs(D - B); D,B
......
......@@ -25,19 +25,38 @@ parent = Path(__file__).parent
def test_class_model_from_file():
try:
model = load_model(
str(parent / "model_files/train_classifier_fail_overlap.dat"),
str(parent / "model_files/test_classifier.dat"),
)
raise ValueError("Model created that should fail")
except RuntimeError:
pass
try:
model = load_model(
str(parent / "model_files/train_classifier.dat"),
str(parent / "model_files/test_classifier_fail_overlap.dat"),
)
raise ValueError("Model created that should fail")
except RuntimeError:
pass
model = load_model(
str(parent / "model_files/train_classifier.dat"),
str(parent / "model_files/test_classifier.dat"),
)
mat_fxn_fn = "model_classifier.m"
mat_fxn_fn = "test_matlab_fxn/model_classifier"
mat_fxn_fn_real = str(parent / "matlab_functions" / "model_classifier.m")
model.write_matlab_fxn(mat_fxn_fn)
actual_lines = open(mat_fxn_fn_real).readlines()
test_lines = open(mat_fxn_fn).readlines()
test_lines = open(mat_fxn_fn + ".m").readlines()
Path(mat_fxn_fn).unlink()
Path(mat_fxn_fn + ".m").unlink()
Path("test_matlab_fxn").rmdir()
for tl, al in zip(test_lines, actual_lines):
assert tl == al
......
......@@ -25,6 +25,14 @@ parent = Path(__file__).parent
def test_class_model_train_from_file():
try:
model = load_model(
str(parent / "model_files/train_classifier_fail_overlap.dat"),
)
raise ValueError("Model created that should fail")
except RuntimeError:
pass
model = load_model(str(parent / "model_files/train_classifier.dat"))
assert np.all(np.abs(model.fit - model.prop_train) < 1e-7)
......
......@@ -121,11 +121,17 @@ def test_sisso_classifier():
shutil.rmtree("models/")
shutil.rmtree("feature_space/")
assert sisso.models[0][0].n_convex_overlap_train == 4
assert sisso.models[1][0].n_convex_overlap_train == 0
# assert sisso.models[0][0].n_convex_overlap_train == 4
# assert sisso.models[1][0].n_convex_overlap_train == 0
assert sisso.models[0][0].n_convex_overlap_test == 0
assert sisso.models[1][0].n_convex_overlap_test == 0
# assert sisso.models[1][0].n_convex_overlap_test == 0
assert np.all(sisso.prop_train != inputs.prop_train)
assert np.all(sisso.prop_test != inputs.prop_test)
assert np.all(sisso.task_sizes_train == inputs.task_sizes_train)
assert np.all(sisso.task_sizes_test == inputs.task_sizes_test)
if __name__ == "__main__":
......
......@@ -55,7 +55,7 @@ def test_log_reg_model_from_file():
assert model.feats[1].postfix_expr == "3|1|abd"
actual_coefs = [
[1.20, -1.95, 2.194569927587456e-13],
[1.20, -1.95],
]
assert np.all(
......@@ -89,9 +89,10 @@ def test_log_reg_model_from_file():
assert model.percentile_95_ae < 1e-7
assert model.percentile_95_test_ae < 1e-7
print(model.latex_str)
assert (
model.latex_str
== "$\\exp\\left(c_0\\right)\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$"
== "$\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$"
)
......
......@@ -40,7 +40,7 @@ def test_log_reg_model_from_file():
assert model.feats[1].postfix_expr == "3|1|abd"
actual_coefs = [
[1.20, -1.95, 2.194569927587456e-13],
[1.20, -1.95],
]
assert np.all(
......@@ -60,7 +60,7 @@ def test_log_reg_model_from_file():
assert model.percentile_95_ae < 1e-7
assert (
model.latex_str
== "$\\exp\\left(c_0\\right)\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$"
== "$\\left(\\left(B + A\\right)\\right)^{a_0}\\left(\\left(\\left|D - B\\right|\\right)\\right)^{a_1}$"
)
......
......@@ -83,6 +83,12 @@ def test_sisso_log_regressor():
assert sisso.models[1][0].rmse < 1e-7
assert sisso.models[1][0].test_rmse < 1e-7
assert np.all(np.abs(sisso.prop_train - np.log(inputs.prop_train)) < 1e-10)
assert np.all(np.abs(sisso.prop_test - np.log(inputs.prop_test)) < 1e-10)
assert np.all(sisso.task_sizes_train == inputs.task_sizes_train)
assert np.all(sisso.task_sizes_test == inputs.task_sizes_test)
</