From 02fbe149ca6b1932f3343403560d4700d25e3a32 Mon Sep 17 00:00:00 2001 From: Thomas <purcell@fhi-berlin.mpg.de> Date: Mon, 30 Aug 2021 09:12:09 +0200 Subject: [PATCH] Further modifications and bug fixes from coverage tests --- .gitlab-ci.yml | 2 +- src/feature_creation/node/FeatureNode.hpp | 4 +- src/feature_creation/node/ModelNode.cpp | 2 +- src/feature_creation/node/ModelNode.hpp | 13 +-- src/feature_creation/node/Node.hpp | 9 +- .../node/operator_nodes/OperatorNode.hpp | 3 +- .../abs/absolute_value.hpp | 3 +- .../abs_diff/absolute_difference.hpp | 3 +- .../allowed_operator_nodes/add/add.hpp | 3 +- .../allowed_operator_nodes/cb/cube.hpp | 3 +- .../allowed_operator_nodes/cbrt/cube_root.hpp | 3 +- .../allowed_operator_nodes/cos/cos.hpp | 3 +- .../allowed_operator_nodes/div/divide.hpp | 3 +- .../exp/exponential.hpp | 3 +- .../allowed_operator_nodes/inv/inverse.hpp | 3 +- .../allowed_operator_nodes/log/log.hpp | 3 +- .../allowed_operator_nodes/mult/multiply.hpp | 3 +- .../neg_exp/negative_exponential.hpp | 3 +- .../allowed_operator_nodes/sin/sin.hpp | 3 +- .../six_pow/sixth_power.hpp | 3 +- .../allowed_operator_nodes/sq/square.hpp | 3 +- .../sqrt/square_root.hpp | 3 +- .../allowed_operator_nodes/sub/subtract.hpp | 3 +- .../bindings_docstring_keyed.cpp | 35 ++---- .../bindings_docstring_keyed.hpp | 8 +- .../feature_creation/ModelNode.cpp | 8 +- tests/exec_test/default/sisso.json | 1 - .../feature_creation/units/test_untis.cc | 20 ++++ .../test_abs_diff_node.py | 8 +- .../test_feat_generation/test_abs_node.py | 2 +- .../test_feat_generation/test_add_node.py | 7 +- .../test_feat_generation/test_model_node.py | 59 ++++++++++ .../test_feature_space/data.csv | 101 ++++++++++++++++++ .../test_feature_space/selected_features.txt | 5 + ...st_gen_feature_space_selected_from_file.py | 36 +++++++ .../test_parameterize/test_param_abs.py | 9 +- 36 files changed, 317 insertions(+), 66 deletions(-) create mode 100644 tests/pytest/test_feature_creation/test_feature_space/data.csv create mode 100644 tests/pytest/test_feature_creation/test_feature_space/selected_features.txt create mode 100644 tests/pytest/test_feature_creation/test_feature_space/test_gen_feature_space_selected_from_file.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8fb53dee..744859c8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -406,7 +406,7 @@ build-gnu-gcov: - mkdir build_gcov/ - cd build_gcov/ - cmake -DCMAKE_CXX_COMPILER=g++ -DCMAKE_C_COMPILE=gcc -DCMAKE_BUILD_TYPE="Coverage" -DBUILD_TESTS=ON -DBUILD_PARAMS=ON -DBUILD_PYTHON=ON -DCMAKE_INSTALL_PREFIX=../gnu_gcov/ ../ - - make + - make install - make coverage_xml - cd ../ coverage: /^\s*lines:\s*\d+.\d+\%/ diff --git a/src/feature_creation/node/FeatureNode.hpp b/src/feature_creation/node/FeatureNode.hpp index 88e1cb6a..36e03de3 100644 --- a/src/feature_creation/node/FeatureNode.hpp +++ b/src/feature_creation/node/FeatureNode.hpp @@ -64,6 +64,8 @@ protected: std::string _expr; //!< Expression of the feature public: + using Node::n_leaves; + using Node::rung; /** * @brief Base Constructor * @details This is only used for serialization @@ -286,7 +288,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(int cur_rung = 0) const {return cur_rung;} + inline int rung(int cur_rung) const {return cur_rung;} /** * @brief Update the primary feature decomposition of a feature diff --git a/src/feature_creation/node/ModelNode.cpp b/src/feature_creation/node/ModelNode.cpp index d0c3ad66..48eec97f 100644 --- a/src/feature_creation/node/ModelNode.cpp +++ b/src/feature_creation/node/ModelNode.cpp @@ -26,7 +26,7 @@ ModelNode::ModelNode() ModelNode::ModelNode( const unsigned long int feat_ind, - const unsigned long int rung, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, diff --git a/src/feature_creation/node/ModelNode.hpp b/src/feature_creation/node/ModelNode.hpp index ec307c8d..831cbc81 100644 --- a/src/feature_creation/node/ModelNode.hpp +++ b/src/feature_creation/node/ModelNode.hpp @@ -77,6 +77,7 @@ protected: int _n_leaves; //!< The number of primary features (non-unique) this feature contains (The number of leaves of the Binary Expression Tree) public: using Node::n_leaves; + using Node::rung; /** * @brief Base Constructor * @details This is only used for serialization @@ -99,7 +100,7 @@ public: */ ModelNode( const unsigned long int feat_ind, - const unsigned long int rung, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, @@ -335,7 +336,7 @@ public: /** * @brief return the rung of the feature (Height of the binary expression tree - 1) */ - inline int rung(int cur_rung = 0) const {return _rung;} + inline int rung(int cur_rung) const {return _rung;} /** * @brief Update the primary feature decomposition of a feature @@ -401,8 +402,8 @@ public: * @param unit (Unit) Unit of the feature */ ModelNode( - const unsigned long int feat_ind, - const unsigned long int rung, + const int feat_ind, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, @@ -429,8 +430,8 @@ public: * @param unit (Unit) Unit of the feature */ ModelNode( - const unsigned long int feat_ind, - const unsigned long int rung, + const int feat_ind, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, diff --git a/src/feature_creation/node/Node.hpp b/src/feature_creation/node/Node.hpp index 8b65134b..06f385ce 100644 --- a/src/feature_creation/node/Node.hpp +++ b/src/feature_creation/node/Node.hpp @@ -329,13 +329,20 @@ public: */ virtual NODE_TYPE type() const = 0; + /** + * @brief Return the rung of the feature (Height of the binary expression tree - 1) + * + * @param cur_rung (int) A recursive helper counter for the rung + */ + virtual int rung(const int cur_rung) const = 0; + // DocString: node_rung /** * @brief Return the rung of the feature (Height of the binary expression tree - 1) * * @param cur_rung (int) A recursive helper counter for the rung */ - virtual int rung(const int cur_rung = 0) const = 0; + inline int rung() const {return rung(0);} /** * @brief Get the primary feature decomposition of a feature diff --git a/src/feature_creation/node/operator_nodes/OperatorNode.hpp b/src/feature_creation/node/operator_nodes/OperatorNode.hpp index 9cbb8243..fd742686 100644 --- a/src/feature_creation/node/operator_nodes/OperatorNode.hpp +++ b/src/feature_creation/node/operator_nodes/OperatorNode.hpp @@ -72,6 +72,7 @@ protected: public: using Node::n_leaves; + using Node::rung; /** * @brief Base Constructor * @details This is only used for serialization @@ -331,7 +332,7 @@ public: * * @param cur_rung (int) A recursive helper counter for the rung */ - virtual int rung(int cur_rung = 0) const = 0; + virtual int rung(int cur_rung) const = 0; /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs/absolute_value.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs/absolute_value.hpp index e61c1886..80dc6ccf 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs/absolute_value.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs/absolute_value.hpp @@ -47,6 +47,7 @@ class AbsNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs_diff/absolute_difference.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs_diff/absolute_difference.hpp index b75a2c43..5e880c1a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs_diff/absolute_difference.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/abs_diff/absolute_difference.hpp @@ -48,6 +48,7 @@ class AbsDiffNode: public OperatorNode<2> } public: + using Node::rung; /** * @brief Base Constructor @@ -138,7 +139,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} + inline int rung(const int cur_rung) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add/add.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add/add.hpp index 6d200d00..bffbece2 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add/add.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add/add.hpp @@ -47,6 +47,7 @@ class AddNode: public OperatorNode<2> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -122,7 +123,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} + inline int rung(const int cur_rung) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} /** * @brief Get the valid LaTeX expression that represents the feature diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cb/cube.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cb/cube.hpp index 72d1a34e..d74dd28a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cb/cube.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cb/cube.hpp @@ -47,6 +47,7 @@ class CbNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cbrt/cube_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cbrt/cube_root.hpp index 86f6ef30..c12dd299 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cbrt/cube_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cbrt/cube_root.hpp @@ -47,6 +47,7 @@ class CbrtNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos/cos.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos/cos.hpp index 84566154..0c57eb6f 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos/cos.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos/cos.hpp @@ -47,6 +47,7 @@ class CosNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/div/divide.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/div/divide.hpp index e8684dee..01f3c460 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/div/divide.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/div/divide.hpp @@ -47,6 +47,7 @@ class DivNode: public OperatorNode<2> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -137,7 +138,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} + inline int rung(const int cur_rung) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exp/exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exp/exponential.hpp index e3bee539..e0387165 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exp/exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exp/exponential.hpp @@ -47,6 +47,7 @@ class ExpNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inv/inverse.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inv/inverse.hpp index 5b04ee52..973fa42a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inv/inverse.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inv/inverse.hpp @@ -43,6 +43,7 @@ class InvNode: public OperatorNode<1> } public: + using Node::rung; /** * @brief Base Constructor @@ -129,7 +130,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/log.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/log.hpp index 233e3dbe..246d754f 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/log.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/log.hpp @@ -47,6 +47,7 @@ class LogNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -138,7 +139,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/mult/multiply.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/mult/multiply.hpp index 284aacea..a43fdb8f 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/mult/multiply.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/mult/multiply.hpp @@ -48,6 +48,7 @@ class MultNode: public OperatorNode<2> } public: + using Node::rung; /** * @brief Base Constructor @@ -137,7 +138,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} + inline int rung(const int cur_rung) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/neg_exp/negative_exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/neg_exp/negative_exponential.hpp index 420c8fe3..97d4ab99 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/neg_exp/negative_exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/neg_exp/negative_exponential.hpp @@ -48,6 +48,7 @@ class NegExpNode: public OperatorNode<1> } public: + using Node::rung; /** * @brief Base Constructor @@ -133,7 +134,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin/sin.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin/sin.hpp index 5dd50188..a254182b 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin/sin.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin/sin.hpp @@ -48,6 +48,7 @@ class SinNode: public OperatorNode<1> } public: + using Node::rung; /** * @brief Base Constructor @@ -133,7 +134,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/six_pow/sixth_power.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/six_pow/sixth_power.hpp index 6a295650..00baa8f2 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/six_pow/sixth_power.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/six_pow/sixth_power.hpp @@ -48,6 +48,7 @@ class SixPowNode: public OperatorNode<1> } public: + using Node::rung; /** * @brief Base Constructor @@ -133,7 +134,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sq/square.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sq/square.hpp index 39eb3bfd..416736bb 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sq/square.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sq/square.hpp @@ -47,6 +47,7 @@ class SqNode: public OperatorNode<1> ar & boost::serialization::base_object<OperatorNode>(*this); } public: + using Node::rung; /** * @brief Base Constructor @@ -132,7 +133,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sqrt/square_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sqrt/square_root.hpp index 07b2d1da..fc96cfe7 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sqrt/square_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sqrt/square_root.hpp @@ -48,6 +48,7 @@ class SqrtNode: public OperatorNode<1> } public: + using Node::rung; /** * @brief Base Constructor @@ -133,7 +134,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return _feats[0]->rung(cur_rung + 1);} + inline int rung(const int cur_rung) const {return _feats[0]->rung(cur_rung + 1);} /** * @brief Returns the type of node this is diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sub/subtract.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sub/subtract.hpp index ea2b3d0d..12e967c0 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sub/subtract.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sub/subtract.hpp @@ -49,6 +49,7 @@ class SubNode: public OperatorNode<2> } public: + using Node::rung; /** * @brief Base Constructor @@ -138,7 +139,7 @@ public: * @brief return the rung of the feature (Height of the binary expression tree - 1) * */ - inline int rung(const int cur_rung=0) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} + inline int rung(const int cur_rung) const {return std::max(_feats[0]->rung(cur_rung + 1), _feats[1]->rung(cur_rung + 1));} /** * @brief Returns the type of node this is diff --git a/src/python/py_binding_cpp_def/bindings_docstring_keyed.cpp b/src/python/py_binding_cpp_def/bindings_docstring_keyed.cpp index 0b7f48f3..def747f6 100644 --- a/src/python/py_binding_cpp_def/bindings_docstring_keyed.cpp +++ b/src/python/py_binding_cpp_def/bindings_docstring_keyed.cpp @@ -66,12 +66,14 @@ void sisso::register_all() "@DocString_node_utils_phi_sel_from_file@" ); - void (*init_val_ar_list)(py::list, py::list, int, int, bool) = &node_value_arrs::initialize_values_arr; - void (*init_val_ar_arr)(np::ndarray, np::ndarray, int, int, bool) = &node_value_arrs::initialize_values_arr; void (*init_val_ar_list_no_params)(py::list, py::list, int, int) = &node_value_arrs::initialize_values_arr; void (*init_val_ar_arr_no_params)(np::ndarray, np::ndarray, int, int) = &node_value_arrs::initialize_values_arr; + #ifdef PARAMETERIZE + void (*init_val_ar_list)(py::list, py::list, int, int, bool) = &node_value_arrs::initialize_values_arr; + void (*init_val_ar_arr)(np::ndarray, np::ndarray, int, int, bool) = &node_value_arrs::initialize_values_arr; + def( "initialize_values_arr", init_val_ar_list, @@ -84,6 +86,7 @@ void sisso::register_all() (arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung"), arg("use_params")), "@DocString_node_vals_ts_arr@" ); + #endif def( "initialize_values_arr", @@ -341,6 +344,7 @@ void sisso::feature_creation::registerUnit() void (Node::*reindex_1)(unsigned long int) = &Node::reindex; void (Node::*reindex_2)(unsigned long int, unsigned long int) = &Node::reindex; int (Node::*n_leaves_prop)() const = &Node::n_leaves; + int (Node::*rung_prop)() const = &Node::rung; class_<sisso::feature_creation::node::NodeWrap, boost::noncopyable>("Node", "@DocString_cls_node@", no_init) .add_property("n_samp", &Node::n_samp, "@DocString_node_n_samp@") @@ -358,12 +362,12 @@ void sisso::feature_creation::registerUnit() .add_property("latex_expr", &Node::latex_expr, "@DocString_node_latex_expr@") .add_property("parameters", &Node::parameters_py, "@DocString_node_parameters_py@") .add_property("n_leaves", n_leaves_prop, "@DocString_node_n_leaves@") + .add_property("rung", rung_prop, "@DocString_node_rung@") .def("reindex", reindex_1, (arg("self"), arg("feat_ind")), "@DocString_node_reindex_1@") .def("reindex", reindex_2, (arg("self"), arg("feat_ind"), arg("arr_ind")), "@DocString_node_reindex_2@") .def("unit", pure_virtual(&Node::unit), (arg("self")), "@DocString_node_unit@") .def("is_nan", pure_virtual(&Node::is_nan), (arg("self")), "@DocString_node_is_nan@") .def("is_const", pure_virtual(&Node::is_const), (arg("self")), "@DocString_node_is_const@") - .def("rung", pure_virtual(&Node::rung), (arg("self"), arg("cur_rung")), "@DocString_node_rung@") .def("n_feats", pure_virtual(&Node::n_feats), (arg("self")), "@DocString_node_n_feats@") .def("feat", pure_virtual(&Node::feat), (arg("self"), arg("feat_ind")), "@DocString_node_feat@") .def("x_in_expr_list", pure_virtual(&Node::get_x_in_expr_list), (arg("self")), "@DocString_node_x_in_expr@") @@ -374,6 +378,8 @@ void sisso::feature_creation::registerUnit() { void (Node::*reindex_1)(unsigned long int) = &Node::reindex; void (Node::*reindex_2)(unsigned long int, unsigned long int) = &Node::reindex; + int (Node::*n_leaves_prop)() const = &Node::n_leaves; + int (Node::*rung_prop)() const = &Node::rung; class_<sisso::feature_creation::node::NodeWrap, boost::noncopyable>("Node", "@DocString_cls_node@", no_init) .add_property("n_samp", &Node::n_samp, "@DocString_node_n_samp@") .add_property("n_samp_test", &Node::n_samp_test, "@DocString_node_n_samp_test@") @@ -386,15 +392,15 @@ void sisso::feature_creation::registerUnit() .add_property("primary_feat_decomp", &Node::primary_feature_decomp_py, "@DocString_node_primary_feature_decomp@") .add_property("postfix_expr", &Node::postfix_expr, "@DocString_node_postfix_expr@") .add_property("latex_expr", &Node::latex_expr, "@DocString_node_latex_expr@") + .add_property("n_leaves", n_leaves_prop, "@DocString_node_n_leaves@") + .add_property("rung", rung_prop, "@DocString_node_rung@") .def("reindex", reindex_1, (arg("self"), arg("feat_ind")), "@DocString_node_reindex_1@") .def("reindex", reindex_2, (arg("self"), arg("feat_ind"), arg("arr_ind")), "@DocString_node_reindex_2@") .def("unit", pure_virtual(&Node::unit), (arg("self")), "@DocString_node_unit@") .def("is_nan", pure_virtual(&Node::is_nan), (arg("self")), "@DocString_node_is_nan@") .def("is_const", pure_virtual(&Node::is_const), (arg("self")), "@DocString_node_is_const@") - .def("rung", pure_virtual(&Node::rung), (arg("self"), arg("cur_rung")), "@DocString_node_rung@") .def("n_feats", pure_virtual(&Node::n_feats), (arg("self")), "@DocString_node_n_feats@") .def("feat", pure_virtual(&Node::feat), (arg("self"), arg("feat_ind")), "@DocString_node_feat@") - .def("n_leaves", pure_virtual(&Node::n_leaves), (arg("self"), arg("cur_n_leaves")), "@DocString_node_n_leaves@") .def("x_in_expr_list", pure_virtual(&Node::get_x_in_expr_list), (arg("self")), "@DocString_node_x_in_expr@") .def("matlab_fxn_expr", pure_virtual(&Node::matlab_fxn_expr), (arg("self")), "@DocString_node_matlab_expr@") ; @@ -429,7 +435,6 @@ void sisso::feature_creation::node::registerFeatureNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_feat_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_feat_node_expr_const@") .add_property("unit", &FeatureNode::unit, "@DocString_feat_node_unit@") - .add_property("rung", &FeatureNode::rung, "@DocString_feat_node_rung@") .add_property("x_in_expr_list", &FeatureNode::get_x_in_expr_list, "@DocString_feat_node_x_in_expr@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_feat_node_matlab_expr@") ; @@ -470,7 +475,6 @@ void sisso::feature_creation::node::registerModelNode() .def("eval", eval_ndarr, (arg("self"), arg("x_in")), "@DocString_model_node_eval_arr@") .def("eval", eval_list, (arg("self"), arg("x_in")), "@DocString_model_node_eval_list@") .def("eval", eval_dict, (arg("self"), arg("x_in")), "@DocString_model_node_eval_dict@") - .add_property("rung", &ModelNode::rung, "@DocString_model_node_rung@") .add_property("x_in_expr_list", &ModelNode::x_in_expr_list_py, "@DocString_model_node_x_in_expr@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_model_node_matlab_expr@") ; @@ -495,7 +499,6 @@ void sisso::feature_creation::node::registerAddNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_add_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_add_node_expr@") .add_property("unit", &AddNode::unit, "@DocString_add_node_unit@") - .add_property("rung", &AddNode::rung, "@DocString_add_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_add_node_matlab_expr@") ; } @@ -519,7 +522,6 @@ void sisso::feature_creation::node::registerSubNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_sub_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_sub_node_expr@") .add_property("unit", &SubNode::unit, "@DocString_sub_node_unit@") - .add_property("rung", &SubNode::rung, "@DocString_sub_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_sub_node_matlab_expr@") ; } @@ -543,7 +545,6 @@ void sisso::feature_creation::node::registerDivNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_div_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_div_node_expr@") .add_property("unit", &DivNode::unit, "@DocString_div_node_unit@") - .add_property("rung", &DivNode::rung, "@DocString_div_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_div_node_matlab_expr@") ; } @@ -567,7 +568,6 @@ void sisso::feature_creation::node::registerMultNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_mult_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_mult_node_expr@") .add_property("unit", &MultNode::unit, "@DocString_mult_node_unit@") - .add_property("rung", &MultNode::rung, "@DocString_mult_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_mult_node_matlab_expr@") ; } @@ -591,7 +591,6 @@ void sisso::feature_creation::node::registerAbsDiffNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_abs_diff_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_abs_diff_node_expr@") .add_property("unit", &AbsDiffNode::unit, "@DocString_abs_diff_node_unit@") - .add_property("rung", &AbsDiffNode::rung, "@DocString_abs_diff_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_abs_diff_node_matlab_expr@") ; } @@ -615,7 +614,6 @@ void sisso::feature_creation::node::registerAbsNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_abs_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_abs_node_expr@") .add_property("unit", &AbsNode::unit, "@DocString_abs_node_unit@") - .add_property("rung", &AbsNode::rung, "@DocString_abs_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_abs_node_matlab_expr@") ; } @@ -639,7 +637,6 @@ void sisso::feature_creation::node::registerInvNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_inv_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_inv_node_expr@") .add_property("unit", &InvNode::unit, "@DocString_inv_node_unit@") - .add_property("rung", &InvNode::rung, "@DocString_inv_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_inv_node_matlab_expr@") ; } @@ -663,7 +660,6 @@ void sisso::feature_creation::node::registerLogNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_log_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_log_node_expr@") .add_property("unit", &LogNode::unit, "@DocString_log_node_unit@") - .add_property("rung", &LogNode::rung, "@DocString_log_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_log_node_matlab_expr@") ; } @@ -687,7 +683,6 @@ void sisso::feature_creation::node::registerExpNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_exp_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_exp_node_expr@") .add_property("unit", &ExpNode::unit, "@DocString_exp_node_unit@") - .add_property("rung", &ExpNode::rung, "@DocString_exp_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_exp_node_matlab_expr@") ; } @@ -711,7 +706,6 @@ void sisso::feature_creation::node::registerNegExpNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_neg_exp_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_neg_exp_node_expr@") .add_property("unit", &NegExpNode::unit, "@DocString_neg_exp_node_unit@") - .add_property("rung", &NegExpNode::rung, "@DocString_neg_exp_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_neg_exp_node_matlab_expr@") ; } @@ -735,7 +729,6 @@ void sisso::feature_creation::node::registerSinNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_sin_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_sin_node_expr@") .add_property("unit", &SinNode::unit, "@DocString_sin_node_unit@") - .add_property("rung", &SinNode::rung, "@DocString_sin_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_sin_node_matlab_expr@") ; } @@ -759,7 +752,6 @@ void sisso::feature_creation::node::registerCosNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_cos_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_cos_node_expr@") .add_property("unit", &CosNode::unit, "@DocString_cos_node_unit@") - .add_property("rung", &CosNode::rung, "@DocString_cos_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_cos_node_matlab_expr@") ; } @@ -783,7 +775,6 @@ void sisso::feature_creation::node::registerCbNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_cb_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_cb_node_expr@") .add_property("unit", &CbNode::unit, "@DocString_cb_node_unit@") - .add_property("rung", &CbNode::rung, "@DocString_cb_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_cb_node_matlab_expr@") ; } @@ -807,7 +798,6 @@ void sisso::feature_creation::node::registerCbrtNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_cbrt_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_cbrt_node_expr@") .add_property("unit", &CbrtNode::unit, "@DocString_cbrt_node_unit@") - .add_property("rung", &CbrtNode::rung, "@DocString_cbrt_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_cbrt_node_matlab_expr@") ; } @@ -831,7 +821,6 @@ void sisso::feature_creation::node::registerSqNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_sq_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_sq_node_expr@") .add_property("unit", &SqNode::unit, "@DocString_sq_node_unit@") - .add_property("rung", &SqNode::rung, "@DocString_sq_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_sq_node_matlab_expr@") ; } @@ -855,7 +844,6 @@ void sisso::feature_creation::node::registerSqrtNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_sqrt_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_sqrt_node_expr@") .add_property("unit", &SqrtNode::unit, "@DocString_sqrt_node_unit@") - .add_property("rung", &SqrtNode::rung, "@DocString_sqrt_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_sqrt_node_matlab_expr@") ; } @@ -879,7 +867,6 @@ void sisso::feature_creation::node::registerSixPowNode() .def("set_test_value", set_test_value_no_param, (arg("self"), arg("offset"), arg("for_comp")), "@DocString_six_pow_node_set_test_value@") .add_property("expr", expr_no_param, "@DocString_six_pow_node_expr@") .add_property("unit", &SixPowNode::unit, "@DocString_six_pow_node_unit@") - .add_property("rung", &SixPowNode::rung, "@DocString_six_pow_node_rung@") .add_property("matlab_fxn_expr", matlab_expr, "@DocString_six_pow_node_matlab_expr@") ; } diff --git a/src/python/py_binding_cpp_def/bindings_docstring_keyed.hpp b/src/python/py_binding_cpp_def/bindings_docstring_keyed.hpp index 21760c0a..ce4de017 100644 --- a/src/python/py_binding_cpp_def/bindings_docstring_keyed.hpp +++ b/src/python/py_binding_cpp_def/bindings_docstring_keyed.hpp @@ -100,7 +100,7 @@ namespace sisso inline bool is_nan() const {return this->get_override("is_nan")();} inline bool is_const() const {return this->get_override("is_const")();} inline NODE_TYPE type() const {return this->get_override("type")();} - inline int rung(int cur_rung = 0) const {return this->get_override("rung")();} + inline int rung(int cur_rung) const {return this->get_override("rung")();} inline std::map<std::string, int> primary_feature_decomp() const {return this->get_override("primary_feature_decomp")();} inline void update_primary_feature_decomp(std::map<std::string, int>& pf_decomp) const {this->get_override("update_primary_feature_decomp")();} inline void update_postfix(std::string& cur_expr) const {this->get_override("update_postfix")();} @@ -134,7 +134,7 @@ namespace sisso inline void set_value(const double* params, int offset=-1, bool for_comp=false, int depth=1) const {this->get_override("set_value")();} inline void set_test_value(const double* params, int offset=-1, bool for_comp=false, int depth=1) const {this->get_override("set_test_value")();} inline NODE_TYPE type() const {return this->get_override("type")();} - inline int rung(int cur_rung = 0) const {return this->get_override("rung")();} + inline int rung(int cur_rung) const {return this->get_override("rung")();} inline Unit unit() const {return this->get_override("unit")();} inline std::string get_postfix_term() const {return this->get_override("get_postfix_term")();} inline std::string expr(const double* params, const int depth=1) const {return this->get_override("expr")();} @@ -217,7 +217,7 @@ namespace sisso { return this->get_override("type")(); } - inline int rung(int cur_rung = 0) const + inline int rung(int cur_rung) const { return this->get_override("rung")(); } @@ -282,7 +282,7 @@ namespace sisso { return this->get_override("type")(); } - inline int rung(int cur_rung = 0) const + inline int rung(int cur_rung) const { return this->get_override("rung")(); } diff --git a/src/python/py_binding_cpp_def/feature_creation/ModelNode.cpp b/src/python/py_binding_cpp_def/feature_creation/ModelNode.cpp index 2524fd41..832d2f27 100644 --- a/src/python/py_binding_cpp_def/feature_creation/ModelNode.cpp +++ b/src/python/py_binding_cpp_def/feature_creation/ModelNode.cpp @@ -22,8 +22,8 @@ #include "feature_creation/node/ModelNode.hpp" ModelNode::ModelNode( - const unsigned long int feat_ind, - const unsigned long int rung, + const int feat_ind, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, @@ -48,8 +48,8 @@ ModelNode( {} ModelNode::ModelNode( - const unsigned long int feat_ind, - const unsigned long int rung, + const int feat_ind, + const int rung, const std::string expr, const std::string latex_expr, const std::string postfix_expr, diff --git a/tests/exec_test/default/sisso.json b/tests/exec_test/default/sisso.json index 384e552e..7c12e68e 100644 --- a/tests/exec_test/default/sisso.json +++ b/tests/exec_test/default/sisso.json @@ -10,6 +10,5 @@ "leave_out_frac": 0.05, "n_models_store": 1, "leave_out_inds": [0, 1, 2, 60, 61], - "opset": ["add", "sub", "abs_diff", "mult", "div", "inv", "abs", "exp", "log", "sin", "cos", "sq", "cb", "six_pow", "sqrt", "cbrt", "neg_exp"], "fix_intercept": false } diff --git a/tests/googletest/feature_creation/units/test_untis.cc b/tests/googletest/feature_creation/units/test_untis.cc index ea5b33b9..e367712a 100644 --- a/tests/googletest/feature_creation/units/test_untis.cc +++ b/tests/googletest/feature_creation/units/test_untis.cc @@ -38,8 +38,18 @@ namespace { u_2 /= u_1; EXPECT_EQ(u_2, Unit("s/m")); + + u_1 *= u_2; + EXPECT_EQ(u_1, Unit("s")); + + u_1 /= u_2; + EXPECT_EQ(u_1, Unit("m")); + u_2 *= u_2; EXPECT_EQ(u_2, Unit("s * s/m^2")); + EXPECT_EQ(u_2.toLatexString(), "m$^\\text{-2}$s$^\\text{2}$"); + + EXPECT_EQ(Unit().toLatexString(), "Unitless"); } //test mean calculations @@ -69,7 +79,17 @@ namespace { u_2 /= u_1; EXPECT_EQ(u_2, Unit("s/m")); + + u_1 *= u_2; + EXPECT_EQ(u_1, Unit("s")); + + u_1 /= u_2; + EXPECT_EQ(u_1, Unit("m")); + u_2 *= u_2; EXPECT_EQ(u_2, Unit("s * s/m^2")); + EXPECT_EQ(u_2.toLatexString(), "m$^\\text{-2}$s$^\\text{2}$"); + + EXPECT_EQ(Unit().toLatexString(), "Unitless"); } } diff --git a/tests/pytest/test_feature_creation/test_feat_generation/test_abs_diff_node.py b/tests/pytest/test_feature_creation/test_feat_generation/test_abs_diff_node.py index f06158a1..04e35f9f 100644 --- a/tests/pytest/test_feature_creation/test_feat_generation/test_abs_diff_node.py +++ b/tests/pytest/test_feature_creation/test_feat_generation/test_abs_diff_node.py @@ -31,7 +31,13 @@ def test_abs_diff_node(): task_sizes_train = [900] task_sizes_test = [10] - initialize_values_arr(task_sizes_train, task_sizes_test, 4, 2) + initialize_values_arr( + np.array(task_sizes_train, dtype=np.int32), + np.array(task_sizes_test, dtype=np.int32), + 4, + 2, + False, + ) data_1 = np.random.random(task_sizes_train[0]) * 1e10 + 1e-10 test_data_1 = np.random.random(task_sizes_test[0]) * 1e10 + 1e-10 diff --git a/tests/pytest/test_feature_creation/test_feat_generation/test_abs_node.py b/tests/pytest/test_feature_creation/test_feat_generation/test_abs_node.py index 47ac337c..b61a1bab 100644 --- a/tests/pytest/test_feature_creation/test_feat_generation/test_abs_node.py +++ b/tests/pytest/test_feature_creation/test_feat_generation/test_abs_node.py @@ -24,7 +24,7 @@ def test_abs_node(): task_sizes_train = [90] task_sizes_test = [10] - initialize_values_arr(task_sizes_train, task_sizes_test, 3, 2) + initialize_values_arr(task_sizes_train, task_sizes_test, 3, 2, False) data_1 = np.random.random(task_sizes_train[0]) * 1e4 + 1e-10 test_data_1 = np.random.random(task_sizes_test[0]) * 1e4 + 1e-10 diff --git a/tests/pytest/test_feature_creation/test_feat_generation/test_add_node.py b/tests/pytest/test_feature_creation/test_feat_generation/test_add_node.py index a56cd2f1..67d7c37c 100644 --- a/tests/pytest/test_feature_creation/test_feat_generation/test_add_node.py +++ b/tests/pytest/test_feature_creation/test_feat_generation/test_add_node.py @@ -24,7 +24,12 @@ def test_add_node(): task_sizes_train = [90] task_sizes_test = [10] - initialize_values_arr(task_sizes_train, task_sizes_test, 4, 2) + initialize_values_arr( + np.array(task_sizes_train, dtype=np.int32), + np.array(task_sizes_test, dtype=np.int32), + 4, + 2, + ) data_1 = np.random.random(task_sizes_train[0]) * 1e4 + 1e-10 test_data_1 = np.random.random(task_sizes_test[0]) * 1e4 + 1e-10 diff --git a/tests/pytest/test_feature_creation/test_feat_generation/test_model_node.py b/tests/pytest/test_feature_creation/test_feat_generation/test_model_node.py index 4070244a..27264a62 100644 --- a/tests/pytest/test_feature_creation/test_feat_generation/test_model_node.py +++ b/tests/pytest/test_feature_creation/test_feat_generation/test_model_node.py @@ -43,6 +43,31 @@ def test_model_node(): feat_3 = AddNode(feat_1, feat_2, 2, 1e-50, 1e50) model_node = ModelNode(feat_3) + model_node_list = ModelNode( + model_node.feat_ind, + model_node.rung, + model_node.expr, + model_node.latex_expr, + model_node.postfix_expr, + model_node.matlab_fxn_expr, + list(model_node.value), + list(model_node.test_value), + model_node.x_in_expr_list, + model_node.unit, + ) + model_node_arr = ModelNode( + model_node.feat_ind, + model_node.rung, + model_node.expr, + model_node.latex_expr, + model_node.postfix_expr, + model_node.matlab_fxn_expr, + np.array(model_node.value), + np.array(model_node.test_value), + model_node.x_in_expr_list, + model_node.unit, + ) + assert model_node.n_leaves == 2 decomp = model_node.primary_feat_decomp @@ -60,6 +85,40 @@ def test_model_node(): pass assert model_node.matlab_fxn_expr == "(t_a + t_b)" + assert model_node.n_leaves == 2 + + assert model_node.n_leaves == 2 + + decomp = model_node_arr.primary_feat_decomp + assert len(decomp.keys()) == 2 + assert decomp["t_a"] == 1 + assert decomp["t_b"] == 1 + assert model_node_arr.x_in_expr_list[0] == "t_a" + assert model_node_arr.x_in_expr_list[1] == "t_b" + + assert model_node_arr.n_leaves == 2 + try: + model_node_arr.feat(2) + raise ValueError("Accessing feature that should throw an error") + except: + pass + assert model_node_arr.matlab_fxn_expr == "(t_a + t_b)" + + decomp = model_node_list.primary_feat_decomp + assert len(decomp.keys()) == 2 + assert decomp["t_a"] == 1 + assert decomp["t_b"] == 1 + assert model_node_list.x_in_expr_list[0] == "t_a" + assert model_node_list.x_in_expr_list[1] == "t_b" + + assert model_node_list.n_leaves == 2 + try: + model_node_list.feat(2) + raise ValueError("Accessing feature that should throw an error") + except: + pass + assert model_node_list.matlab_fxn_expr == "(t_a + t_b)" + if __name__ == "__main__": test_model_node() diff --git a/tests/pytest/test_feature_creation/test_feature_space/data.csv b/tests/pytest/test_feature_creation/test_feature_space/data.csv new file mode 100644 index 00000000..00650540 --- /dev/null +++ b/tests/pytest/test_feature_creation/test_feature_space/data.csv @@ -0,0 +1,101 @@ +Sample,Task,Prop,A (m),B (s),C,D (Unitless) +1,X,1031303.34310437,40047.7725031033,81.6019767547866,12535.2818525271,-683.666065848847 +2,X,207179.181972689,8273.93114052335,47.4359192293739,2518.19019867913,-1407.86160002623 +3,X,594547.990034924,-24495.5390890833,46.3994727792424,7226.59341895378,-154.449699580799 +4,X,1431871.75085735,-5975.17124802999,96.2922472869417,17404.1240046628,-383.63965153104 +5,X,2132341.51391611,33545.2455355934,23.2389997524879,25918.2170844233,-2214.8717939546 +6,X,1849456.85903214,-36585.1506450251,21.7653754396546,22479.8013103184,-499.788202406702 +7,X,416377.473683951,47617.1641535909,53.9342164837372,5060.96052467702,-2002.28785563532 +8,X,1834852.24383494,164.577549590314,55.7417291729005,22302.2848302114,-1462.8889504883 +9,X,2030615.0021387,-25590.077352893,13.3180597514294,24681.7483092487,-267.582565811964 +10,X,418204.906991729,-35631.266855653,67.830087711799,5083.17267158509,-2819.77637904098 +11,X,1600764.65336791,24069.5603461085,91.2031527296231,19456.9890506716,-2706.92171287459 +12,X,-237442.303891325,-28375.8492844066,76.6780058713539,-2886.10976641617,-1650.25772935281 +13,X,389569.403019936,-17679.1039531987,93.7334723703787,4735.11289934218,-553.765889146761 +14,X,1097874.59558522,25271.39171418,53.6965192771211,13344.4443174432,-1094.01486564295 +15,X,896512.426133544,-16691.6898965759,19.4379065649528,10896.9207498079,-2899.60958857901 +16,X,12475.3344165542,11073.3959911305,52.0025761588363,151.597422562947,-782.134708201617 +17,X,643218.531288929,-33665.7156040407,29.7373317632719,7818.17572605823,-1080.66347038372 +18,X,888098.246309737,-42864.1312633446,93.9228362331387,10794.6477981533,-1638.80485180208 +19,X,1636015.66023612,-1874.52319024457,61.4904198919873,19885.4591582095,-2643.77032366468 +20,X,1523022.28471858,-49138.4737863941,17.975585548934,18512.0435328828,-560.378442383903 +21,X,-18066.9165614168,-35122.5184807359,6.32108929256205,-219.638541412487,-1004.04464422701 +22,X,753574.994852389,-504.277781827623,64.3463985117791,9159.54014727008,-690.33547481712 +23,X,484679.670507055,-47904.9616755848,34.793137673643,5891.16232922052,-2871.23133035778 +24,X,1418886.29518641,40005.8303266016,89.663527446701,17246.2879576819,-1230.52218744124 +25,X,746864.366592613,-29303.0557293284,63.1160346689987,9077.97355841423,-3078.94168258733 +26,X,826676.469591929,31855.9700915967,12.4598774065994,10048.0763518243,-3214.1429201838 +27,X,904870.905255709,-1370.05112198737,18.1776031280461,10998.5166695707,-1733.87235240405 +28,X,1081673.04047048,46129.8007590074,65.8763747557873,13147.5171186325,-1237.15538447696 +29,X,1602766.31102942,12215.0498178804,28.9863403535557,19481.3188655265,-2669.08606113272 +30,X,848296.081366335,-8523.54146953082,14.4884132013553,10310.8591252139,-1070.59795231075 +31,X,881987.050483579,-32109.023962203,59.952453510063,10720.3672326848,-1978.64149010475 +32,X,1384967.83924126,31795.5231836559,46.3619825035018,16834.0147857661,-3214.77894538541 +33,X,1435243.99308821,-41605.9821955878,61.1093419800895,17445.1130460068,-1581.87602287648 +34,X,1482822.4415542,-49423.8250112063,57.7898783145655,18023.4211475179,-2245.35073430102 +35,X,1159462.50457973,24974.6967563244,2.46710777290358,14093.035073862,-1653.30479641573 +36,X,1385445.91552098,44300.000697173,14.1598975077974,16839.8257231643,-1154.39523418031 +37,X,1078840.90378916,-33471.5314909414,86.4825835158785,13113.0929698841,-1772.81761496697 +38,X,322072.318257427,-32616.3765208785,71.5517709413264,3914.69709752203,-1834.58611475719 +39,X,1547503.57192612,15339.6613906795,78.8203546957091,18809.6094936362,-538.87662795121 +40,X,1174714.5075073,38777.544632935,63.0951620300882,14278.4206242917,-380.323852794412 +41,X,94875.3402808423,12249.6781769406,90.3127462736438,1153.15574346477,-1590.10909636815 +42,X,362160.364120508,-49277.9984660007,8.3266338235128,4401.96060534147,-1423.02119058586 +43,X,673617.378755157,21157.5642575089,40.4360003803782,8187.66864424759,-2304.57417593545 +44,X,882351.052225793,44482.8684188695,60.148559750113,10724.7916131167,-3010.89784032583 +45,X,22400.9390066318,17108.6417404538,68.2422016131663,272.24149001619,-1091.87923472037 +46,X,1781136.79777257,30136.189144163,65.8784392884513,21649.382366535,-779.999951946907 +47,X,621416.608280441,-31495.5881531396,67.4176383345993,7553.17699006137,-3091.37667023128 +48,X,750411.885581194,42277.9111802948,52.7091601206799,9121.09305972893,-1213.67564944238 +49,X,1525062.49801326,-20619.9327982041,18.5983023041602,18536.841985025,-518.413321644593 +50,X,679068.208535292,42337.0868480189,55.8737535970023,8253.92257061978,-1337.41889839093 +51,X,447826.687204506,-3841.47148699515,57.8803936758992,5443.22046731452,-2117.64647879144 +52,X,336890.280723035,-25698.4911052116,26.2484582718796,4094.80695856079,-2304.9408398086 +53,X,468079.149217039,-36421.9167980631,9.52225176867021,5689.38576313015,-2346.34809901136 +54,X,1404060.53519045,10116.138294505,33.8807589471792,17066.0833189846,-2177.75555908996 +55,X,1827150.95390431,33677.6712656449,65.3664484400669,22208.6767557623,-768.872566798946 +56,X,-33394.4572217261,23643.588170146,95.3617653535894,-405.942240360551,-802.333589068958 +57,X,1443453.59596531,48648.6785581152,83.107773775309,17544.8993990111,-1826.75004222983 +58,X,1550858.36965351,39565.5654401456,28.6332188363784,18850.3865001573,-176.047021901582 +59,X,329623.778660326,9384.94614690253,83.9023194218408,4006.48383865674,-1510.2742546313 +60,X,596362.271476793,-7862.7530203713,84.8842436218459,7248.64570723748,-1125.70379322904 +61,Y,-1747903.77060764,-15426.701719437,73.530132833494,12277.508278164,-2388.05382648639 +62,Y,-602031.002716425,-26628.9177804096,56.1127291052339,4228.72153980883,-2494.38544516297 +63,Y,-914915.654901957,-22908.9603476779,55.2235512174418,6426.47150794108,-3065.18336481344 +64,Y,-1293976.98085175,44255.5466634393,24.3327718109724,9089.05662095335,-1530.79847762564 +65,Y,-556992.07118952,44821.3470186639,63.0165978378747,3912.36115061704,-3240.22306333347 +66,Y,-2294033.39637973,-44132.4823446645,42.5612469609221,16113.6068505302,-255.147829778129 +67,Y,-1213629.16675478,-42539.8069752961,48.9584343155192,8524.68120454026,-164.586906089718 +68,Y,292809.005099769,-49432.7543013633,80.507648968553,-2056.77244309559,-871.09190770659 +69,Y,-2235342.64861732,-6632.95213361424,93.4293107228537,15701.3540037878,-2178.77545326323 +70,Y,-3732932.41042696,-40485.6986880114,25.9765685287417,26220.6550555191,-598.407067002771 +71,Y,-252210.474776827,-14427.735364365,59.6676061209021,1771.52829396117,-845.471004201544 +72,Y,-98889.6656695742,-41488.1745504839,42.4820587894618,694.579325611127,-1299.98047519081 +73,Y,-1370204.37668197,-21879.2550863842,34.7942407834795,9624.48958514065,-1954.71594115708 +74,Y,-3104420.51726401,31704.4165180227,44.4564228685462,21805.8907533925,-977.750657934738 +75,Y,-2388277.98822188,7800.82135026513,48.5821408939988,16775.5953738995,-2577.18311095899 +76,Y,-724977.658463333,-44659.6170999659,35.6876655675306,5092.31777838026,-2837.25474789309 +77,Y,-1794477.91392858,6521.923601348,88.7042922408313,12604.6522325595,-2393.39103443748 +78,Y,-223213.115899978,28443.9701603649,37.3226807484787,1567.84638066104,-1284.75416736837 +79,Y,423046.005849878,9502.16765496074,17.4038852841401,-2971.57718730938,-793.452569769765 +80,Y,-3047818.61588223,7598.41423185622,90.0700126497531,21408.3102848147,-749.371738309082 +81,Y,-2409342.6015377,35261.072039404,47.9286965191158,16923.5564603709,-3048.15567690909 +82,Y,-742814.585466495,-29503.3166005498,7.75349725175401,5217.60709978568,-2729.9120626205 +83,Y,-571579.430647006,-44941.8628170447,85.8317735174233,4014.82500929801,-1269.94347716121 +84,Y,-2195610.2686634,48026.9824672444,3.47886888914346,15422.2676483873,-208.387321904327 +85,Y,-964020.379427545,-5862.59066560875,32.3951971412924,6771.39065366773,-2348.00221246913 +86,Y,-2102214.66994452,-1627.31398929461,65.1915191571454,14766.2426011092,-2448.65166476797 +87,Y,-890649.179337315,-31734.0384124326,73.7172018923155,6256.02004752945,-586.069879271884 +88,Y,-2207063.83218629,14835.6206610657,31.7102632894148,15502.7192420416,-1698.88417254839 +89,Y,-749402.325380223,-49686.6769123602,49.3012898909983,5263.8803991841,-1176.36020313534 +90,Y,-2494089.08559485,234.017339793194,43.1649546520338,17518.8293606888,-2223.27305100155 +91,Y,-758480.09438593,-42219.6177653841,85.476183481532,5327.64404629679,-864.677157209562 +92,Y,-2025827.98191011,2374.67279858794,33.5495503844189,14229.6907309965,-2868.43169850788 +93,Y,-2354065.35735529,-48559.373111767,43.9360775681768,16535.2805868339,-1226.37195019107 +94,Y,-1588621.54314025,-37866.8557345306,22.4186822710487,11158.6853894212,-2716.07040834036 +95,Y,-3175419.95188679,-45432.4026527398,31.3118028803292,22304.6017131089,-666.77340835222 +96,Y,-2152215.92330461,-26966.2051976371,0.258766409063485,15117.4590856704,-32.6895291544268 +97,Y,-547157.630624095,-1300.97533450509,46.2515307967681,3843.28254500137,-2502.56292413987 +98,Y,-2672876.70357122,28750.3814277021,7.66749583919236,18774.6605662742,-1875.23509974759 +99,Y,-2080211.9597305,-40822.549051454,89.438883925997,14611.6921599612,-1948.30990769798 +100,Y,-2578377.05246833,-2300.90575344433,65.926962237196,18110.8804765956,-2076.35142495637 diff --git a/tests/pytest/test_feature_creation/test_feature_space/selected_features.txt b/tests/pytest/test_feature_creation/test_feature_space/selected_features.txt new file mode 100644 index 00000000..0a50bb93 --- /dev/null +++ b/tests/pytest/test_feature_creation/test_feature_space/selected_features.txt @@ -0,0 +1,5 @@ +# FEAT_ID Feature Postfix Expression (RPN) +0 3|2|add|3|abs|add +#----------------------------------------------------------------------- +1 1|0|div|0|div +#----------------------------------------------------------------------- diff --git a/tests/pytest/test_feature_creation/test_feature_space/test_gen_feature_space_selected_from_file.py b/tests/pytest/test_feature_creation/test_feature_space/test_gen_feature_space_selected_from_file.py new file mode 100644 index 00000000..6ed153a6 --- /dev/null +++ b/tests/pytest/test_feature_creation/test_feature_space/test_gen_feature_space_selected_from_file.py @@ -0,0 +1,36 @@ +# Copyright 2021 Thomas A. R. Purcell +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import shutil +import numpy as np +from sissopp import phi_selected_from_file +from sissopp.py_interface import read_csv + +import pathlib + +parent = pathlib.Path(__file__).parent.absolute() + + +def test_gen_feature_space_from_file(): + inputs = read_csv(str(parent / "data.csv"), "Prop", task_key="Task", max_rung=2) + + phi_sel = phi_selected_from_file( + str(parent / "selected_features.txt"), inputs.phi_0 + ) + + assert phi_sel[0].postfix_expr == "3|2|add|3|abs|add" + assert phi_sel[1].postfix_expr == "1|0|div|0|div" + + +if __name__ == "__main__": + test_gen_feature_space_from_file() diff --git a/tests/pytest/test_feature_creation/test_parameterize/test_param_abs.py b/tests/pytest/test_feature_creation/test_parameterize/test_param_abs.py index 5a0d89bc..2d9e0044 100644 --- a/tests/pytest/test_feature_creation/test_parameterize/test_param_abs.py +++ b/tests/pytest/test_feature_creation/test_parameterize/test_param_abs.py @@ -37,8 +37,13 @@ def test_param_abs_node(): task_sizes_train = [900] task_sizes_test = [10] - initialize_values_arr(task_sizes_train, task_sizes_test, 1, 1) - initialize_param_storage() + initialize_values_arr( + np.array(task_sizes_train, dtype=np.int32), + np.array(task_sizes_test, dtype=np.int32), + 1, + 1, + True, + ) data_1 = np.linspace(-20, 20, task_sizes_train[0]) test_data_1 = np.linspace(-19.99, 19.99, task_sizes_test[0]) -- GitLab