diff --git a/src/feature_creation/feature_space/FeatureSpace.cpp b/src/feature_creation/feature_space/FeatureSpace.cpp index 5cebdc24855cf7c6df314f47acf06e951b739107..7948843787e349e486326235e63f808931e44309 100644 --- a/src/feature_creation/feature_space/FeatureSpace.cpp +++ b/src/feature_creation/feature_space/FeatureSpace.cpp @@ -540,6 +540,12 @@ void FeatureSpace::sis(std::vector<double>& prop) node_value_arrs::clear_temp_test_reg(); } + for(auto& feat : phi_sel) + { + feat->set_selected(false); + feat->set_d_mat_ind(-1); + } + // If we are only on one process then phi_sel are the selected features if(_mpi_comm->size() > 1) { diff --git a/src/feature_creation/node/FeatureNode.hpp b/src/feature_creation/node/FeatureNode.hpp index abf9086b0b7dfd539240bd20b8f752224082dc57..265e06952cb48079d39a606c2caea8dccaaf025b 100644 --- a/src/feature_creation/node/FeatureNode.hpp +++ b/src/feature_creation/node/FeatureNode.hpp @@ -243,6 +243,27 @@ public: */ inline std::string get_postfix_term(){return std::to_string(_feat_ind);} + //DocString: feat_node_nfeats + /** + * @brief Number of features used for an operator node + * @return the number of features for an operator node + */ + inline int n_feats(){return 0;} + + //DocString: feat_node_feat + /** + * @brief Return the ind node_ptr in the operator node's feat list + * + * @param ind the index of the node to access + * @return the ind feature in feature_list + */ + inline node_ptr feat(int ind) + { + if(ind > 0) + throw std::logic_error("Index not found in _feats"); + return nullptr; + } + /** * @brief update the dictionary used to check if an Add/Sub node is valid * diff --git a/src/feature_creation/node/Node.hpp b/src/feature_creation/node/Node.hpp index ef7e7b5101dbad58c020424def4046cb992b8f97..c7cb6be0dc8823ab7b13f22dc716a7a80dcfb8ce 100644 --- a/src/feature_creation/node/Node.hpp +++ b/src/feature_creation/node/Node.hpp @@ -316,6 +316,21 @@ public: */ virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0; + //DocString: node_nfeats + /** + * @brief Number of features used for an operator node + * @return the number of features for an operator node + */ + virtual int n_feats() = 0; + + //DocString: node_feat + /** + * @brief Return the ind node_ptr in the operator node's feat list + * + * @param ind the index of the node to access + * @return the ind feature in feature_list + */ + virtual std::shared_ptr<Node> feat(int ind) = 0; #ifdef PY_BINDINGS diff --git a/src/feature_creation/node/operator_nodes/OperatorNode.hpp b/src/feature_creation/node/operator_nodes/OperatorNode.hpp index 41edc4ddd283d1fd38bf47289d8bd1e23bd2f91f..5c5c97f12ae9d42074890f4eba8574770f9e695a 100644 --- a/src/feature_creation/node/operator_nodes/OperatorNode.hpp +++ b/src/feature_creation/node/operator_nodes/OperatorNode.hpp @@ -255,6 +255,27 @@ public: */ virtual std::string get_postfix_term() = 0; + //DocString: op_node_nfeats + /** + * @brief Number of features used for an operator node + * @return the number of features for an operator node + */ + inline int n_feats(){return N;} + + //DocString: op_node_feat + /** + * @brief Return the ind node_ptr in the operator node's feat list + * + * @param ind the index of the node to access + * @return the ind feature in feature_list + */ + inline node_ptr feat(int ind) + { + if(ind > N) + throw std::logic_error("Index not found in _feats"); + return _feats[ind]; + } + /** * @brief update the dictionary used to check if an Add/Sub node is valid * diff --git a/src/python/bindings_docstring_keyed.cpp b/src/python/bindings_docstring_keyed.cpp index 803c5597d9310932f26e03268efd3b70b528c1c7..4cd218283ae0007dae650c562beb59dfe63a3452 100644 --- a/src/python/bindings_docstring_keyed.cpp +++ b/src/python/bindings_docstring_keyed.cpp @@ -112,6 +112,8 @@ void sisso::feature_creation::node::registerNode() .def("is_nan", pure_virtual(&Node::is_nan), "@DocString_node_is_nan@") .def("is_const", pure_virtual(&Node::is_const), "@DocString_node_is_const@") .def("rung", pure_virtual(&Node::rung), "@DocString_node_rung@") + .def("n_feats", pure_virtual(&Node::n_feats), "@DocString_node_n_feats@") + .def("feat", pure_virtual(&Node::feat), "@DocString_node_feat@") ; } diff --git a/src/python/bindings_docstring_keyed.hpp b/src/python/bindings_docstring_keyed.hpp index e11568b8946ff41e22497b10a8af101e441dbd37..1ebc81b505aa614146400f2acf208d8eeb7c9736 100644 --- a/src/python/bindings_docstring_keyed.hpp +++ b/src/python/bindings_docstring_keyed.hpp @@ -52,6 +52,8 @@ namespace sisso std::string get_postfix_term(){return this->get_override("get_postfix_term")();} inline void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot){this->get_override("update_add_sub_leaves");} inline void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot){this->get_override("update_div_mult_leaves");} + inline int n_feats(){this->get_override("n_feats");} + inline std::shared_ptr<Node> feat(int ind){this->get_override("feat");} }; /** * @brief struct used wrap an OperatorNode object for conversion @@ -99,6 +101,9 @@ namespace sisso .def("rung", py::pure_virtual(&OperatorNode<N>::rung), "@DocString_op_node_rung@") .def("expr", py::pure_virtual(&OperatorNode<N>::expr), "@DocString_op_node_expr@") .def("unit", py::pure_virtual(&OperatorNode<N>::unit), "@DocString_op_node_unit@") + .add_property("n_feats", &OperatorNode<N>::n_feats, "@DocString_op_node_n_feats@") + .add_property("feat", &OperatorNode<N>::feat, "@DocString_op_node_feat@") + ; }