diff --git a/src/feature_creation/node/FeatureNode.cpp b/src/feature_creation/node/FeatureNode.cpp index 373bb586195ebd011949d389e49d349762092b2e..490b2b860dcb0379bda40bbaa90ac2f317877841 100644 --- a/src/feature_creation/node/FeatureNode.cpp +++ b/src/feature_creation/node/FeatureNode.cpp @@ -15,4 +15,24 @@ FeatureNode::FeatureNode(const FeatureNode &o) : Node(o) {} +void FeatureNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + if(add_sub_leaves.count(_expr) > 0) + add_sub_leaves[_expr] += pl_mn; + else + add_sub_leaves[_expr] = pl_mn; + + ++expected_abs_tot; +} + +void FeatureNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + if(div_mult_leaves.count(_expr) > 0) + div_mult_leaves[_expr] += fact; + else + div_mult_leaves[_expr] = fact; + + expected_abs_tot += std::abs(fact); +} + // BOOST_CLASS_EXPORT(FeatureNode) diff --git a/src/feature_creation/node/FeatureNode.hpp b/src/feature_creation/node/FeatureNode.hpp index 8537d57248b9ee7f94b6b64d54f351fae6ecaaea..f813e23a1ae8a29bfc3aa767f73150464e2a7ec6 100644 --- a/src/feature_creation/node/FeatureNode.hpp +++ b/src/feature_creation/node/FeatureNode.hpp @@ -88,6 +88,23 @@ public: */ inline int rung(int cur_rung = 0){return cur_rung;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + /** * @brief Serialization function to send over MPI * diff --git a/src/feature_creation/node/Node.hpp b/src/feature_creation/node/Node.hpp index 35b2e9416eb490ed820ee810a4a3521db0720e4a..89ec7456026af6fa01d86b524ff9ecb2ef279078 100644 --- a/src/feature_creation/node/Node.hpp +++ b/src/feature_creation/node/Node.hpp @@ -97,6 +97,24 @@ public: */ virtual int rung(int cur_rung = 0) = 0; + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0; + + /** + * @brief update the dictionary used to check if an Mult/Div node is valid + * + * @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0; + /** * @brief Serialization function to send over MPI * diff --git a/src/feature_creation/node/operator_nodes/OperatorNode.hpp b/src/feature_creation/node/operator_nodes/OperatorNode.hpp index 6cd8f4244e15fc4ef817ccc17978bb10e1ef0ca5..79390ecd695a2198e2dee11c355b3ad0fcbc47ac 100644 --- a/src/feature_creation/node/operator_nodes/OperatorNode.hpp +++ b/src/feature_creation/node/operator_nodes/OperatorNode.hpp @@ -90,6 +90,24 @@ public: * @brief Returns the type of node this is */ virtual NODE_TYPE type() = 0; + + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0; + + /** + * @brief update the dictionary used to check if an Mult/Div node is valid + * + * @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid + * @param fact amount to increment the dictionary by + */ + virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0; + }; #endif diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp index 2453e9a40766ddf29ee380208c80e2ff20b96553..1dc3b2bce3ebaa7b48b08250f11d9b01815fec4a 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp @@ -11,7 +11,19 @@ AbsDiffNode::AbsDiffNode(std::vector<node_ptr> feats, int rung, int feat_ind) : if(feats[0]->unit() != feats[1]->unit()) throw InvalidFeatureException(); - if((feats[0]->type() == NODE_TYPE::LOG) && (feats[1]->type() == NODE_TYPE::LOG)) + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) + throw InvalidFeatureException(); + + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) throw InvalidFeatureException(); set_value(); @@ -25,10 +37,43 @@ AbsDiffNode::AbsDiffNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_in if(feat_1->unit() != feat_2->unit()) throw InvalidFeatureException(); - if((feat_1->type() == NODE_TYPE::LOG) && (feat_2->type() == NODE_TYPE::LOG)) + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) throw InvalidFeatureException(); + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) + throw InvalidFeatureException(); set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void AbsDiffNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void AbsDiffNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp index 4ce38e04a77c11f669628ef26369f2870c827d8f..26477c60d364c0175f932af44a2c674204fbc34c 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::ABS_DIFF;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp index 4426869a25a6c1d7cb8436cc6551a4e46d0ec56e..d740f8029e6f7e7458b1096dc62b10cec1e9b0a8 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp @@ -20,3 +20,25 @@ AbsNode::AbsNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void AbsNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void AbsNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp index 07523c4a6fddd9d03b71ab36f5a95c9af62d6429..ae3221f41a619fe8fc12bb3e3bc1c03aea538374 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::ABS;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp index a500268db7c78e31a0d19520929b07b68a520edc..2922339361b4fbd385b59386d29652e38174af01 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp @@ -9,6 +9,21 @@ AddNode::AddNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != feats[1]->unit()) throw InvalidFeatureException(); + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) + throw InvalidFeatureException(); + + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); @@ -20,7 +35,39 @@ AddNode::AddNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): if(feat_1->unit() != feat_2->unit()) throw InvalidFeatureException(); + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) + throw InvalidFeatureException(); + + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void AddNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + _feats[0]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot); + _feats[1]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot); +} + +void AddNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp index 6e496d05f920bc7584d97d51f61fd769cdbe7f57..a11e280f823b0557a27337503a5faaf45b511321 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::ADD;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp index 954d858b3636f816fe0609b07890bcb5ef1818cb..10167f15e7185b2b53aa760bc734f909541a95c2 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp @@ -30,3 +30,26 @@ CosNode::CosNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void CosNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void CosNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} + diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp index 29fa37531ad8e446209cf2c5aef2dc2a46320e35..ceca1724965953819986946ce513d4635da72635 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::COS;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp index e5d652d340a6b50fb10facb3e95b77850359329b..650817c17ca1a2b244c384f4055075a171e5787d 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp @@ -24,3 +24,19 @@ CbNode::CbNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void CbNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void CbNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 3.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp index 3bad615dd54991c2a9037377e98b25c0c62105f5..b7e9ca7d6e2a8626b1daedd9664615c324692e89 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::CB;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp index 21eb0e8b566056cd4d89a24d122df96589b3014b..8c0f7b166238d9c350cfd3e9314099018ecc24c0 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp @@ -24,3 +24,19 @@ CbrtNode::CbrtNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void CbrtNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void CbrtNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact / 3.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp index fa78ec94089efc352226bfda621c98bac391019f..2d723daf9224f71a5a13eb6d19fb5eeac2640659 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::CBRT;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leavesis valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp index b116cc12264e5a51d6b68448b305327c85b778b4..74961765ca798fa43198ccaac9b0de6776e94832 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp @@ -9,6 +9,21 @@ DivNode::DivNode(std::vector<node_ptr> feats, int rung, int feat_ind): if((feats[0]->type() == NODE_TYPE::INV) || (feats[1]->type() == NODE_TYPE::INV)) throw InvalidFeatureException(); + std::map<std::string, double> div_mult_leaves; + double expected_abs_tot = 0.0; + update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot); + + if((div_mult_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12) + throw InvalidFeatureException(); + + int div_mult_tot_first = div_mult_leaves.begin()->second; + + if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); @@ -20,7 +35,39 @@ DivNode::DivNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): if((feat_1->type() == NODE_TYPE::INV) || (feat_2->type() == NODE_TYPE::INV)) throw InvalidFeatureException(); + std::map<std::string, double> div_mult_leaves; + double expected_abs_tot = 0.0; + update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot); + + if((div_mult_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12) + throw InvalidFeatureException(); + + int div_mult_tot_first = div_mult_leaves.begin()->second; + + if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void DivNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void DivNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot); + _feats[1]->update_div_mult_leaves(div_mult_leaves, -1.0*fact, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp index 755acd61904ba1ce039b4c47b4988da318597f18..a98ad4755b7fb0c56755c756cbd1ca969d16eca7 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::DIV;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp index 3a6f47ae712a15273725592b22042c7469b2ddaa..8e81ea5fbe2d529a1c9011b61cace2a526cafd29 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp @@ -30,3 +30,25 @@ ExpNode::ExpNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void ExpNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void ExpNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp index 605bd4f04fb93894aed26470c9bcd3611b0bc09c..5ed79cbc3fa471ae65a8548af28986e5dcda4789 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::EXP;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp index 1cfe5f501b29cd59d7ec05fe6383133c19422b66..bdccb2d19b629e0c7aaae9c8ca8e61766d9bb2e9 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp @@ -6,7 +6,7 @@ InvNode::InvNode() InvNode::InvNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { - if(feats[0]->type() == NODE_TYPE::DIV) + if((feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::NEG_EXP)) throw InvalidFeatureException(); set_value(); @@ -17,10 +17,26 @@ InvNode::InvNode(std::vector<node_ptr> feats, int rung, int feat_ind): InvNode::InvNode(node_ptr feat, int rung, int feat_ind): OperatorNode({feat}, rung, feat_ind) { - if(feat->type() == NODE_TYPE::DIV) + if((feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::NEG_EXP)) throw InvalidFeatureException(); set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void InvNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void InvNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * -1.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp index 4f2c24b2263c05b3d16dbc92663229123c56cb8d..46c088a3c71d763f65a146652c5b7b7433ec397b 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::INV;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp index 6196fcd413519439c80e4e6ae67b3ee0d9fd7065..cab91e516d6d138cd42be1b5faea50c0bb759fa7 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp @@ -9,7 +9,7 @@ LogNode::LogNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != Unit()) throw InvalidFeatureException(); - if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::MULT) || (feats[0]->type() == NODE_TYPE::LOG)) + if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::INV) || (feats[0]->type() == NODE_TYPE::MULT) || (feats[0]->type() == NODE_TYPE::LOG) || (feats[0]->type() == NODE_TYPE::SIX_POW) || (feats[0]->type() == NODE_TYPE::CB) || (feats[0]->type() == NODE_TYPE::SQ) || (feats[0]->type() == NODE_TYPE::CBRT) || (feats[0]->type() == NODE_TYPE::SQRT)) throw InvalidFeatureException(); set_value(); @@ -23,10 +23,32 @@ LogNode::LogNode(node_ptr feat, int rung, int feat_ind): if(feat->unit() != Unit()) throw InvalidFeatureException(); - if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::MULT) || (feat->type() == NODE_TYPE::LOG)) + if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::INV) || (feat->type() == NODE_TYPE::MULT) || (feat->type() == NODE_TYPE::LOG) || (feat->type() == NODE_TYPE::SIX_POW) || (feat->type() == NODE_TYPE::CB) || (feat->type() == NODE_TYPE::SQ) || (feat->type() == NODE_TYPE::CBRT) || (feat->type() == NODE_TYPE::SQRT)) throw InvalidFeatureException(); set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void LogNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void LogNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp index 8ba84f86828ae282d84f940322406c8ceb3f617c..00d6fa1eb984c21ec4929f4a86329728180d5559 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::LOG;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp index 7e3b23cda6d5f01a7e2b719767b2bd8c73259f44..c92d37efeac8808e5329524b45ff09c535427ea7 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp @@ -6,6 +6,21 @@ MultNode::MultNode() MultNode::MultNode(std::vector<node_ptr> feats, int rung, int feat_ind): OperatorNode(feats, rung, feat_ind) { + std::map<std::string, double> div_mult_leaves; + double expected_abs_tot = 0.0; + update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot); + + if((div_mult_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12) + throw InvalidFeatureException(); + + int div_mult_tot_first = div_mult_leaves.begin()->second; + + if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); @@ -14,8 +29,39 @@ MultNode::MultNode(std::vector<node_ptr> feats, int rung, int feat_ind): MultNode::MultNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): OperatorNode({feat_1, feat_2}, rung, feat_ind) { + std::map<std::string, double> div_mult_leaves; + double expected_abs_tot = 0.0; + update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot); + + if((div_mult_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12) + throw InvalidFeatureException(); + + int div_mult_tot_first = div_mult_leaves.begin()->second; + + if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } +void MultNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void MultNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot); + _feats[1]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp index aa93bbd9a4814b6e91df2ee52df42a4488186f23..b29516e6d84c21f7261e9f582322984f325cf95f 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::MULT;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp index 6cb67bb1f6cb75459f5ebfcecbf90fb613944d03..4776b3964cfc52c7a0349b98d56b77fcf9432075 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp @@ -32,3 +32,25 @@ NegExpNode::NegExpNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void NegExpNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void NegExpNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = "exp(" + _feats[0]->expr() + ")"; + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] -= fact; + else + div_mult_leaves[key] = -1.0 * fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp index 32e4d7859c43cef9bc0284f09b21b44cf79d7f78..e08b206b102ed6b16cb7e1304a287070d3dd3baa 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::NEG_EXP;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp index ceba2dded9d97bc8af0bca223313b7a0c6f15b32..5aafb0533fcfe10ebe80b8ca8f4b6bc6de5ca3bf 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp @@ -30,3 +30,25 @@ SinNode::SinNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void SinNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void SinNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp index d871a87e23ca804d2fa7125d1ffd0dda78459378..db423ac2c741a06e93f6c8d533e304554bebb94e 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::SIN;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp index 3d1063d4e86181888d7633ba6763b21a0f3b035c..4c85272e3951687d9fa53bb696eb651c31268825 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp @@ -24,3 +24,19 @@ SixPowNode::SixPowNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void SixPowNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void SixPowNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 6.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp index 6c2535be438fe96704ebfd0ba2a49437933aec89..57e4716069eae22d6ca5f15ac5c5b1d4f9d23831 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::SIX_POW;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp index 082f9e30299ca761e554572455776e5a403a9e62..4523668d484cfc662e6abf7685f486d0b7ebc335 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp @@ -24,3 +24,19 @@ SqNode::SqNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void SqNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void SqNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 2.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp index 93ed269c1e348ab486e5a870bc93cd9d4cfdede7..669e69b45415ac909f633c0e482a580b53f9029e 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::SQ;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp index 0bbcaa9acf6666f5cd58443dc1e4efabe3a0a835..50cd210864ae046dd1a728e043d0f522ff7f7745 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp @@ -24,3 +24,19 @@ SqrtNode::SqrtNode(node_ptr feat, int rung, int feat_ind): if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void SqrtNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + std::string key = expr(); + if(add_sub_leaves.count(key) > 0) + add_sub_leaves[key] += pl_mn; + else + add_sub_leaves[key] = pl_mn; + + ++expected_abs_tot; +} + +void SqrtNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + _feats[0]->update_div_mult_leaves(div_mult_leaves, fact / 2.0, expected_abs_tot); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp index 086328a428fa0484adb1e3a3944e06d539d5011e..e1f33e68499716c468c9a4d1070d3ee326919c80 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::SQRT;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) { diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp index 9c98078ab4143bab73a91898b8a4f0c7e0173665..801d9d77f0b4e67b2123fa694e49520772888880 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp @@ -9,6 +9,21 @@ SubNode::SubNode(std::vector<node_ptr> feats, int rung, int feat_ind): if(feats[0]->unit() != feats[1]->unit()) throw InvalidFeatureException(); + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) + throw InvalidFeatureException(); + + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); @@ -20,7 +35,39 @@ SubNode::SubNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind): if(feat_1->unit() != feat_2->unit()) throw InvalidFeatureException(); + std::map<std::string, int> add_sub_leaves; + int expected_abs_tot = 0; + update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot); + + if((add_sub_leaves.size() < 2)) + throw InvalidFeatureException(); + + if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0) + throw InvalidFeatureException(); + + int add_sub_tot_first = add_sub_leaves.begin()->second; + + if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;})) + throw InvalidFeatureException(); + set_value(); if(is_nan() || is_const()) throw InvalidFeatureException(); } + +void SubNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) +{ + _feats[0]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot); + _feats[1]->update_add_sub_leaves(add_sub_leaves, -1*pl_mn, expected_abs_tot); +} + +void SubNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) +{ + std::string key = expr(); + if(div_mult_leaves.count(key) > 0) + div_mult_leaves[key] += fact; + else + div_mult_leaves[key] = fact; + + expected_abs_tot *= std::abs(fact); +} diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp index daa24450ba43d00c85f4fc55eda2cb6d79a12c68..c8b167d24ed442a7106b82015591767ec4e93211 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp @@ -33,6 +33,25 @@ public: */ inline NODE_TYPE type(){return NODE_TYPE::SUB;} + /** + * @brief update the dictionary used to check if an Add/Sub node is valid + * + * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid + * @param pl_mn if for an addition node: 1 if for a subtraction node: -1 + * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves + */ + void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot); + + /** + * @brief update the dictionary used to check if + * @details [long description] + * + * @param add_sub_leaves [description] + * @param fact amount to increment the dictionary by + * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves + */ + void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot); + template <typename Archive> void serialize(Archive& ar, const unsigned int version) {