From 8f0ade75261f81f26e46f800c720600e629dbf9d Mon Sep 17 00:00:00 2001 From: Thomas Purcell <purcell@fhi-berlin.mpg.de> Date: Mon, 1 Jun 2020 16:08:31 +0200 Subject: [PATCH] Minor refactors Added type condition tests (initial simplification) Added som more documenatation --- .../feature_space/FeatureSpace.cpp | 16 +++------------- src/feature_creation/node/FeatureNode.hpp | 5 +++++ src/feature_creation/node/Node.hpp | 5 +++++ .../node/operator_nodes/OperatorNode.hpp | 5 +++++ .../absolute_difference.cpp | 10 ++++++++-- .../absolute_difference.hpp | 5 +++++ .../allowed_operator_nodes/absolute_value.cpp | 4 ++-- .../allowed_operator_nodes/absolute_value.hpp | 5 +++++ .../allowed_operator_nodes/add.cpp | 4 ++-- .../allowed_operator_nodes/add.hpp | 5 +++++ .../allowed_operator_nodes/cos.cpp | 10 ++++++++-- .../allowed_operator_nodes/cos.hpp | 5 +++++ .../allowed_operator_nodes/cube.cpp | 10 ++++++++-- .../allowed_operator_nodes/cube.hpp | 5 +++++ .../allowed_operator_nodes/cube_root.cpp | 10 ++++++++-- .../allowed_operator_nodes/cube_root.hpp | 5 +++++ .../allowed_operator_nodes/divide.cpp | 10 ++++++++-- .../allowed_operator_nodes/divide.hpp | 5 +++++ .../allowed_operator_nodes/exponential.cpp | 10 ++++++++-- .../allowed_operator_nodes/exponential.hpp | 5 +++++ .../allowed_operator_nodes/inverse.cpp | 10 ++++++++-- .../allowed_operator_nodes/inverse.hpp | 5 +++++ .../allowed_operator_nodes/log.cpp | 10 ++++++++-- .../allowed_operator_nodes/log.hpp | 5 +++++ .../allowed_operator_nodes/multiply.cpp | 4 ++-- .../allowed_operator_nodes/multiply.hpp | 5 +++++ .../negative_exponential.cpp | 12 ++++++++++-- .../negative_exponential.hpp | 5 +++++ .../allowed_operator_nodes/sin.cpp | 10 ++++++++-- .../allowed_operator_nodes/sin.hpp | 5 +++++ .../allowed_operator_nodes/sixth_power.cpp | 10 ++++++++-- .../allowed_operator_nodes/sixth_power.hpp | 5 +++++ .../allowed_operator_nodes/square.cpp | 10 ++++++++-- .../allowed_operator_nodes/square.hpp | 5 +++++ .../allowed_operator_nodes/square_root.cpp | 10 ++++++++-- .../allowed_operator_nodes/square_root.hpp | 5 +++++ .../allowed_operator_nodes/subtract.cpp | 4 ++-- .../allowed_operator_nodes/subtract.hpp | 5 +++++ .../value_storage/nodes_value_containers.cpp | 1 + .../value_storage/nodes_value_containers.hpp | 1 + 40 files changed, 219 insertions(+), 47 deletions(-) diff --git a/src/feature_creation/feature_space/FeatureSpace.cpp b/src/feature_creation/feature_space/FeatureSpace.cpp index 07c09289..956c3397 100644 --- a/src/feature_creation/feature_space/FeatureSpace.cpp +++ b/src/feature_creation/feature_space/FeatureSpace.cpp @@ -126,30 +126,21 @@ void FeatureSpace::generate_feature_space() } } _mpi_comm->barrier(); - if(_mpi_comm->rank() == 0) - std::cout << "NEXT_PHI MADE" << std::endl; - else - std::cout << "NNNN NEXT_PHI MADE" << std::endl; _start_gen.push_back(_phi.size()); std::vector<std::vector<node_ptr>> next_phi_gathered; mpi::all_gather(*_mpi_comm, next_phi, next_phi_gathered); - if(_mpi_comm->rank() == 0) - std::cout << "all gather" << std::endl; - else - std::cout << "aaaa all_gather" << std::endl; - + std::cout << nn << " set values" << std::endl; for(auto& next_phi_vec : next_phi_gathered) { + _phi.reserve(_phi.size() + next_phi_vec.size()); for(auto& feat : next_phi_vec) { if(nn <= node_value_arrs::N_RUNGS_STORED) { - std::transform(_phi.begin(), _phi.end(), scores.begin(), [&feat](node_ptr f){return 1.0 - std::abs(util_funcs::r(feat->value_ptr(), f->value_ptr(), f->n_samp()));}); - if(*std::min_element(scores.begin(), scores.begin()+_phi.size()) > 1e-13) - _phi.push_back(feat); feat->set_value(); + _phi.push_back(feat); } else { @@ -157,7 +148,6 @@ void FeatureSpace::generate_feature_space() } } } - std::cout << "DONE"<< std::endl; } _n_feat = _phi.size(); } diff --git a/src/feature_creation/node/FeatureNode.hpp b/src/feature_creation/node/FeatureNode.hpp index 68f0a1a5..8d6db1ee 100644 --- a/src/feature_creation/node/FeatureNode.hpp +++ b/src/feature_creation/node/FeatureNode.hpp @@ -91,6 +91,11 @@ public: return std::all_of(value_ptr(), value_ptr() + _n_samp, [&mean](double d){return std::abs(d - mean) < 1e-12;}); } + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::FEAT;} + /** * @brief Accessor function to the value of the feature */ diff --git a/src/feature_creation/node/Node.hpp b/src/feature_creation/node/Node.hpp index fbc18b80..da370a64 100644 --- a/src/feature_creation/node/Node.hpp +++ b/src/feature_creation/node/Node.hpp @@ -105,6 +105,11 @@ public: */ virtual bool is_const() = 0; + /** + * @brief Returns the type of node this is + */ + virtual NODE_TYPE type() = 0; + /** * @brief Serialization function to send over MPI * diff --git a/src/feature_creation/node/operator_nodes/OperatorNode.hpp b/src/feature_creation/node/operator_nodes/OperatorNode.hpp index e2cf435f..1a3825b4 100644 --- a/src/feature_creation/node/operator_nodes/OperatorNode.hpp +++ b/src/feature_creation/node/operator_nodes/OperatorNode.hpp @@ -97,6 +97,11 @@ public: return std::all_of(value_ptr(), value_ptr() + _n_samp, [&mean](double d){return std::abs(d - mean) < 1e-12;}); } + /** + * @brief Returns the type of node this is + */ + virtual NODE_TYPE type() = 0; + /** * @brief Set up the feature value pointers */ diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp index 4d4a5271..c0590ee1 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp @@ -11,8 +11,11 @@ AbsDiffNode::AbsDiffNode(std::vector<node_ptr> feats, int rung, int feat_ind) : if(feats[0]->unit() != feats[1]->unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::LOG) && (feats[1]->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -22,8 +25,11 @@ AbsDiffNode::AbsDiffNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_in if(feat_1->unit() != feat_2->unit()) throw InvalidFeatureException(); + if((feat_1->type() == NODE_TYPE::LOG) && (feat_2->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp index 9cf2742f..0d3ec143 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::ABS_DIFF;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp index 8212270e..71f35c94 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp @@ -9,7 +9,7 @@ AbsNode::AbsNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -17,7 +17,7 @@ AbsNode::AbsNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp index c38f71b4..20afc8c5 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::ABS;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp index 0d7081ba..1a75d051 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp @@ -10,7 +10,7 @@ AddNode::AddNode(std::vector<node_ptr> feats, int rung, int feat_ind): throw InvalidFeatureException(); set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -21,7 +21,7 @@ AddNode::AddNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): throw InvalidFeatureException(); set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp index 6ae01b11..faff638c 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::ADD;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp index 37f6eab6..07d40fb1 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp @@ -9,8 +9,11 @@ CosNode::CosNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::SIN) || (feats[0]->type() == NODE_TYPE::COS)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -20,8 +23,11 @@ CosNode::CosNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); + if((feat->type() == NODE_TYPE::SIN) || (feat->type() == NODE_TYPE::COS)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp index 04c39c81..bd246c74 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::COS;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp index 1e441a6f..72cd11c9 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp @@ -6,16 +6,22 @@ CbNode::CbNode() CbNode::CbNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if(feats[0]->type() == NODE_TYPE::CBRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } CbNode::CbNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if(feat->type() == NODE_TYPE::CBRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp index c62c9351..671d7312 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::CB;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp index eb2517b1..18ab351a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp @@ -6,16 +6,22 @@ CbrtNode::CbrtNode() CbrtNode::CbrtNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if(feats[0]->type() == NODE_TYPE::CB) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } CbrtNode::CbrtNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if(feat->type() == NODE_TYPE::CB) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp index 29bab700..88f0e935 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::CBRT;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp index 8138be10..2fcd1955 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp @@ -6,16 +6,22 @@ DivNode::DivNode() DivNode::DivNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if((feats[0]->type() == NODE_TYPE::INV) || (feats[1]->type() == NODE_TYPE::INV)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } DivNode::DivNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): OperatorNode({feat_1, feat_2}, rung, feat_ind) { + if((feat_1->type() == NODE_TYPE::INV) || (feat_2->type() == NODE_TYPE::INV)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp index 5d5e248e..08607cc8 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::DIV;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp index a2a8d845..ed2450d6 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp @@ -9,8 +9,11 @@ ExpNode::ExpNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::ADD) || (feats[0]->type() == NODE_TYPE::SUB) || (feats[0]->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -20,8 +23,11 @@ ExpNode::ExpNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); + if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::ADD) || (feat->type() == NODE_TYPE::SUB) || (feat->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp index 6b5ed68f..46eb66c4 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::EXP;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp index 6449a65c..e33471a1 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp @@ -6,16 +6,22 @@ InvNode::InvNode() InvNode::InvNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if(feats[0]->type() == NODE_TYPE::DIV) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } InvNode::InvNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if(feat->type() == NODE_TYPE::DIV) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp index 7a6e3fb4..6bd02488 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::INV;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp index 0f9939b0..2a3e77e2 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp @@ -9,8 +9,11 @@ LogNode::LogNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::MULT) || (feats[0]->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -20,8 +23,11 @@ LogNode::LogNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); + if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::MULT) || (feat->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp index b728a1d4..db792788 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::LOG;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp index 88661f5c..66729c80 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp @@ -7,7 +7,7 @@ MultNode::MultNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -15,7 +15,7 @@ MultNode::MultNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): OperatorNode({feat_1, feat_2}, rung, feat_ind) { set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp index 9fc77618..9460ac6f 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::MULT;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp index f706206e..ac6ff996 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp @@ -9,8 +9,12 @@ NegExpNode::NegExpNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::ADD) || (feats[0]->type() == NODE_TYPE::SUB) || (feats[0]->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -20,8 +24,12 @@ NegExpNode::NegExpNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); + if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::ADD) || (feat->type() == NODE_TYPE::SUB) || (feat->type() == NODE_TYPE::LOG)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp index 01aa9c0c..129036fd 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::NEG_EXP;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp index 65ada0ce..c9845a1e 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp @@ -9,8 +9,11 @@ SinNode::SinNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); + if((feats[0]->type() == NODE_TYPE::SIN) || (feats[0]->type() == NODE_TYPE::COS)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -20,8 +23,11 @@ SinNode::SinNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); + if((feat->type() == NODE_TYPE::SIN) || (feat->type() == NODE_TYPE::COS)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp index aa7c415e..306381cc 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::SIN;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp index c419acfa..e69e3195 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp @@ -6,16 +6,22 @@ SixPowNode::SixPowNode() SixPowNode::SixPowNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if((feats[0]->type() == NODE_TYPE::CBRT) || (feats[0]->type() == NODE_TYPE::SQRT)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } SixPowNode::SixPowNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if((feat->type() == NODE_TYPE::CBRT) || (feat->type() == NODE_TYPE::SQRT)) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp index 06057fa1..85062f99 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::SIX_POW;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp index 2f23801a..1fe69a5a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp @@ -6,16 +6,22 @@ SqNode::SqNode() SqNode::SqNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if(feats[0]->type() == NODE_TYPE::SQRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } SqNode::SqNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if(feat->type() == NODE_TYPE::SQRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp index 118eb3c4..8ced5433 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::SQ;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp index 8a3b89e0..09673fd7 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp @@ -6,16 +6,22 @@ SqrtNode::SqrtNode() SqrtNode::SqrtNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + if(feats[0]->type() == NODE_TYPE::SQRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } SqrtNode::SqrtNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { + if(feat->type() == NODE_TYPE::SQRT) + throw InvalidFeatureException(); + set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp index bed54677..ab93e8e1 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::SQRT;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp index 3874f9c0..4b7e6c7d 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp @@ -10,7 +10,7 @@ SubNode::SubNode(std::vector<node_ptr> feats, int rung, int feat_ind): throw InvalidFeatureException(); set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } @@ -21,7 +21,7 @@ SubNode::SubNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): throw InvalidFeatureException(); set_value(); -if(is_nan() || is_const()) + if(is_nan() || is_const()) throw InvalidFeatureException(); } diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp index 10fc8a6e..5be24fe0 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp @@ -19,6 +19,11 @@ public: void set_value(); + /** + * @brief Returns the type of node this is + */ + inline NODE_TYPE type(){return NODE_TYPE::SUB;} + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/value_storage/nodes_value_containers.cpp b/src/feature_creation/node/value_storage/nodes_value_containers.cpp index 2d2a02bd..d14f1c61 100644 --- a/src/feature_creation/node/value_storage/nodes_value_containers.cpp +++ b/src/feature_creation/node/value_storage/nodes_value_containers.cpp @@ -57,5 +57,6 @@ void node_value_arrs::setup_values_arr(int n_samples, int n_rung, int n_primary_ VALUES_ARR = std::unique_ptr<double[]>(new double[N_STORE_FEATURES * N_SAMPLES]); TEMP_STORAGE_ARR = std::unique_ptr<double[]>(new double[3 * N_STORE_FEATURES * N_SAMPLES]); TEMP_STORAGE_REG = std::unique_ptr<int[]>(new int[3 * N_STORE_FEATURES]); + std::copy_n(std::vector<int>(3*N_STORE_FEATURES, -1).data(), 3*N_STORE_FEATURES, TEMP_STORAGE_REG.get()); } diff --git a/src/feature_creation/node/value_storage/nodes_value_containers.hpp b/src/feature_creation/node/value_storage/nodes_value_containers.hpp index d06c82ac..daefbeba 100644 --- a/src/feature_creation/node/value_storage/nodes_value_containers.hpp +++ b/src/feature_creation/node/value_storage/nodes_value_containers.hpp @@ -1,6 +1,7 @@ #ifndef NODE_VALEU_ARR #define NODE_VALEU_ARR +#include <algorithm> #include <memory> #include <vector> -- GitLab