From 77ab05f920d06c825e7792386ffa4e4919d94f3b Mon Sep 17 00:00:00 2001
From: Thomas Purcell <purcell@fhi-berlin.mpg.de>
Date: Tue, 2 Jun 2020 14:34:46 +0200
Subject: [PATCH] Basic Simiplification Added for Add, Subtract, Abs_Diff, Mult
and Div Nodes
Checks for cancelations and duplicate features
---
src/feature_creation/node/FeatureNode.cpp | 20 ++++++++
src/feature_creation/node/FeatureNode.hpp | 17 +++++++
src/feature_creation/node/Node.hpp | 18 +++++++
.../node/operator_nodes/OperatorNode.hpp | 18 +++++++
.../absolute_difference.cpp | 49 ++++++++++++++++++-
.../absolute_difference.hpp | 19 +++++++
.../allowed_operator_nodes/absolute_value.cpp | 22 +++++++++
.../allowed_operator_nodes/absolute_value.hpp | 19 +++++++
.../allowed_operator_nodes/add.cpp | 47 ++++++++++++++++++
.../allowed_operator_nodes/add.hpp | 19 +++++++
.../allowed_operator_nodes/cos.cpp | 23 +++++++++
.../allowed_operator_nodes/cos.hpp | 19 +++++++
.../allowed_operator_nodes/cube.cpp | 16 ++++++
.../allowed_operator_nodes/cube.hpp | 19 +++++++
.../allowed_operator_nodes/cube_root.cpp | 16 ++++++
.../allowed_operator_nodes/cube_root.hpp | 19 +++++++
.../allowed_operator_nodes/divide.cpp | 47 ++++++++++++++++++
.../allowed_operator_nodes/divide.hpp | 19 +++++++
.../allowed_operator_nodes/exponential.cpp | 22 +++++++++
.../allowed_operator_nodes/exponential.hpp | 19 +++++++
.../allowed_operator_nodes/inverse.cpp | 20 +++++++-
.../allowed_operator_nodes/inverse.hpp | 19 +++++++
.../allowed_operator_nodes/log.cpp | 26 +++++++++-
.../allowed_operator_nodes/log.hpp | 19 +++++++
.../allowed_operator_nodes/multiply.cpp | 46 +++++++++++++++++
.../allowed_operator_nodes/multiply.hpp | 19 +++++++
.../negative_exponential.cpp | 22 +++++++++
.../negative_exponential.hpp | 19 +++++++
.../allowed_operator_nodes/sin.cpp | 22 +++++++++
.../allowed_operator_nodes/sin.hpp | 19 +++++++
.../allowed_operator_nodes/sixth_power.cpp | 16 ++++++
.../allowed_operator_nodes/sixth_power.hpp | 19 +++++++
.../allowed_operator_nodes/square.cpp | 16 ++++++
.../allowed_operator_nodes/square.hpp | 19 +++++++
.../allowed_operator_nodes/square_root.cpp | 16 ++++++
.../allowed_operator_nodes/square_root.hpp | 19 +++++++
.../allowed_operator_nodes/subtract.cpp | 47 ++++++++++++++++++
.../allowed_operator_nodes/subtract.hpp | 19 +++++++
38 files changed, 863 insertions(+), 6 deletions(-)
diff --git a/src/feature_creation/node/FeatureNode.cpp b/src/feature_creation/node/FeatureNode.cpp
index 373bb586..490b2b86 100644
--- a/src/feature_creation/node/FeatureNode.cpp
+++ b/src/feature_creation/node/FeatureNode.cpp
@@ -15,4 +15,24 @@ FeatureNode::FeatureNode(const FeatureNode &o) :
Node(o)
{}
+void FeatureNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ if(add_sub_leaves.count(_expr) > 0)
+ add_sub_leaves[_expr] += pl_mn;
+ else
+ add_sub_leaves[_expr] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void FeatureNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ if(div_mult_leaves.count(_expr) > 0)
+ div_mult_leaves[_expr] += fact;
+ else
+ div_mult_leaves[_expr] = fact;
+
+ expected_abs_tot += std::abs(fact);
+}
+
// BOOST_CLASS_EXPORT(FeatureNode)
diff --git a/src/feature_creation/node/FeatureNode.hpp b/src/feature_creation/node/FeatureNode.hpp
index 8537d572..f813e23a 100644
--- a/src/feature_creation/node/FeatureNode.hpp
+++ b/src/feature_creation/node/FeatureNode.hpp
@@ -88,6 +88,23 @@ public:
*/
inline int rung(int cur_rung = 0){return cur_rung;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
/**
* @brief Serialization function to send over MPI
*
diff --git a/src/feature_creation/node/Node.hpp b/src/feature_creation/node/Node.hpp
index 35b2e941..89ec7456 100644
--- a/src/feature_creation/node/Node.hpp
+++ b/src/feature_creation/node/Node.hpp
@@ -97,6 +97,24 @@ public:
*/
virtual int rung(int cur_rung = 0) = 0;
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0;
+
+ /**
+ * @brief update the dictionary used to check if an Mult/Div node is valid
+ *
+ * @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0;
+
/**
* @brief Serialization function to send over MPI
*
diff --git a/src/feature_creation/node/operator_nodes/OperatorNode.hpp b/src/feature_creation/node/operator_nodes/OperatorNode.hpp
index 6cd8f424..79390ecd 100644
--- a/src/feature_creation/node/operator_nodes/OperatorNode.hpp
+++ b/src/feature_creation/node/operator_nodes/OperatorNode.hpp
@@ -90,6 +90,24 @@ public:
* @brief Returns the type of node this is
*/
virtual NODE_TYPE type() = 0;
+
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0;
+
+ /**
+ * @brief update the dictionary used to check if an Mult/Div node is valid
+ *
+ * @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid
+ * @param fact amount to increment the dictionary by
+ */
+ virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0;
+
};
#endif
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp
index 2453e9a4..1dc3b2bc 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.cpp
@@ -11,7 +11,19 @@ AbsDiffNode::AbsDiffNode(std::vector<node_ptr> feats, int rung, int feat_ind) :
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
- if((feats[0]->type() == NODE_TYPE::LOG) && (feats[1]->type() == NODE_TYPE::LOG))
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
+ throw InvalidFeatureException();
+
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
throw InvalidFeatureException();
set_value();
@@ -25,10 +37,43 @@ AbsDiffNode::AbsDiffNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_in
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
- if((feat_1->type() == NODE_TYPE::LOG) && (feat_2->type() == NODE_TYPE::LOG))
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
throw InvalidFeatureException();
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
+ throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void AbsDiffNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void AbsDiffNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp
index 4ce38e04..26477c60 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_difference.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS_DIFF;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp
index 4426869a..d740f802 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.cpp
@@ -20,3 +20,25 @@ AbsNode::AbsNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void AbsNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void AbsNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp
index 07523c4a..ae3221f4 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/absolute_value.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp
index a500268d..29223393 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.cpp
@@ -9,6 +9,21 @@ AddNode::AddNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
+ throw InvalidFeatureException();
+
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
@@ -20,7 +35,39 @@ AddNode::AddNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
+ throw InvalidFeatureException();
+
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void AddNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ _feats[0]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot);
+ _feats[1]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot);
+}
+
+void AddNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp
index 6e496d05..a11e280f 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/add.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ADD;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp
index 954d858b..10167f15 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.cpp
@@ -30,3 +30,26 @@ CosNode::CosNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void CosNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void CosNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
+
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp
index 29fa3753..ceca1724 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cos.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::COS;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp
index e5d652d3..650817c1 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.cpp
@@ -24,3 +24,19 @@ CbNode::CbNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void CbNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void CbNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 3.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp
index 3bad615d..b7e9ca7d 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::CB;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp
index 21eb0e8b..8c0f7b16 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.cpp
@@ -24,3 +24,19 @@ CbrtNode::CbrtNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void CbrtNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void CbrtNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact / 3.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp
index fa78ec94..2d723daf 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/cube_root.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::CBRT;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leavesis valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp
index b116cc12..74961765 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.cpp
@@ -9,6 +9,21 @@ DivNode::DivNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if((feats[0]->type() == NODE_TYPE::INV) || (feats[1]->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
+ std::map<std::string, double> div_mult_leaves;
+ double expected_abs_tot = 0.0;
+ update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
+
+ if((div_mult_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
+ throw InvalidFeatureException();
+
+ int div_mult_tot_first = div_mult_leaves.begin()->second;
+
+ if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
@@ -20,7 +35,39 @@ DivNode::DivNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
if((feat_1->type() == NODE_TYPE::INV) || (feat_2->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
+ std::map<std::string, double> div_mult_leaves;
+ double expected_abs_tot = 0.0;
+ update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
+
+ if((div_mult_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
+ throw InvalidFeatureException();
+
+ int div_mult_tot_first = div_mult_leaves.begin()->second;
+
+ if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void DivNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void DivNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot);
+ _feats[1]->update_div_mult_leaves(div_mult_leaves, -1.0*fact, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp
index 755acd61..a98ad475 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/divide.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::DIV;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp
index 3a6f47ae..8e81ea5f 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.cpp
@@ -30,3 +30,25 @@ ExpNode::ExpNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void ExpNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void ExpNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp
index 605bd4f0..5ed79cbc 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/exponential.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::EXP;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp
index 1cfe5f50..bdccb2d1 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.cpp
@@ -6,7 +6,7 @@ InvNode::InvNode()
InvNode::InvNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
- if(feats[0]->type() == NODE_TYPE::DIV)
+ if((feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::NEG_EXP))
throw InvalidFeatureException();
set_value();
@@ -17,10 +17,26 @@ InvNode::InvNode(std::vector<node_ptr> feats, int rung, int feat_ind):
InvNode::InvNode(node_ptr feat, int rung, int feat_ind):
OperatorNode({feat}, rung, feat_ind)
{
- if(feat->type() == NODE_TYPE::DIV)
+ if((feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::NEG_EXP))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void InvNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void InvNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * -1.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp
index 4f2c24b2..46c088a3 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/inverse.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::INV;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp
index 6196fcd4..cab91e51 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.cpp
@@ -9,7 +9,7 @@ LogNode::LogNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != Unit())
throw InvalidFeatureException();
- if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::MULT) || (feats[0]->type() == NODE_TYPE::LOG))
+ if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::DIV) || (feats[0]->type() == NODE_TYPE::INV) || (feats[0]->type() == NODE_TYPE::MULT) || (feats[0]->type() == NODE_TYPE::LOG) || (feats[0]->type() == NODE_TYPE::SIX_POW) || (feats[0]->type() == NODE_TYPE::CB) || (feats[0]->type() == NODE_TYPE::SQ) || (feats[0]->type() == NODE_TYPE::CBRT) || (feats[0]->type() == NODE_TYPE::SQRT))
throw InvalidFeatureException();
set_value();
@@ -23,10 +23,32 @@ LogNode::LogNode(node_ptr feat, int rung, int feat_ind):
if(feat->unit() != Unit())
throw InvalidFeatureException();
- if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::MULT) || (feat->type() == NODE_TYPE::LOG))
+ if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::DIV) || (feat->type() == NODE_TYPE::INV) || (feat->type() == NODE_TYPE::MULT) || (feat->type() == NODE_TYPE::LOG) || (feat->type() == NODE_TYPE::SIX_POW) || (feat->type() == NODE_TYPE::CB) || (feat->type() == NODE_TYPE::SQ) || (feat->type() == NODE_TYPE::CBRT) || (feat->type() == NODE_TYPE::SQRT))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void LogNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void LogNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp
index 8ba84f86..00d6fa1e 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::LOG;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp
index 7e3b23cd..c92d37ef 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.cpp
@@ -6,6 +6,21 @@ MultNode::MultNode()
MultNode::MultNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
+ std::map<std::string, double> div_mult_leaves;
+ double expected_abs_tot = 0.0;
+ update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
+
+ if((div_mult_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
+ throw InvalidFeatureException();
+
+ int div_mult_tot_first = div_mult_leaves.begin()->second;
+
+ if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
@@ -14,8 +29,39 @@ MultNode::MultNode(std::vector<node_ptr> feats, int rung, int feat_ind):
MultNode::MultNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
OperatorNode({feat_1, feat_2}, rung, feat_ind)
{
+ std::map<std::string, double> div_mult_leaves;
+ double expected_abs_tot = 0.0;
+ update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
+
+ if((div_mult_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
+ throw InvalidFeatureException();
+
+ int div_mult_tot_first = div_mult_leaves.begin()->second;
+
+ if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+void MultNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void MultNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot);
+ _feats[1]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp
index aa93bbd9..b29516e6 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/multiply.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::MULT;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp
index 6cb67bb1..4776b396 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.cpp
@@ -32,3 +32,25 @@ NegExpNode::NegExpNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void NegExpNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void NegExpNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = "exp(" + _feats[0]->expr() + ")";
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] -= fact;
+ else
+ div_mult_leaves[key] = -1.0 * fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp
index 32e4d785..e08b206b 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/negative_exponential.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::NEG_EXP;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp
index ceba2dde..5aafb053 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.cpp
@@ -30,3 +30,25 @@ SinNode::SinNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void SinNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void SinNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp
index d871a87e..db423ac2 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sin.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::SIN;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp
index 3d1063d4..4c85272e 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.cpp
@@ -24,3 +24,19 @@ SixPowNode::SixPowNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void SixPowNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void SixPowNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 6.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp
index 6c2535be..57e47160 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/sixth_power.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::SIX_POW;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp
index 082f9e30..4523668d 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.cpp
@@ -24,3 +24,19 @@ SqNode::SqNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void SqNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void SqNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 2.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp
index 93ed269c..669e69b4 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::SQ;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp
index 0bbcaa9a..50cd2108 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.cpp
@@ -24,3 +24,19 @@ SqrtNode::SqrtNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void SqrtNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ std::string key = expr();
+ if(add_sub_leaves.count(key) > 0)
+ add_sub_leaves[key] += pl_mn;
+ else
+ add_sub_leaves[key] = pl_mn;
+
+ ++expected_abs_tot;
+}
+
+void SqrtNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ _feats[0]->update_div_mult_leaves(div_mult_leaves, fact / 2.0, expected_abs_tot);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp
index 086328a4..e1f33e68 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/square_root.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::SQRT;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp
index 9c98078a..801d9d77 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.cpp
@@ -9,6 +9,21 @@ SubNode::SubNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
+ throw InvalidFeatureException();
+
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
@@ -20,7 +35,39 @@ SubNode::SubNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
+ std::map<std::string, int> add_sub_leaves;
+ int expected_abs_tot = 0;
+ update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
+
+ if((add_sub_leaves.size() < 2))
+ throw InvalidFeatureException();
+
+ if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
+ throw InvalidFeatureException();
+
+ int add_sub_tot_first = add_sub_leaves.begin()->second;
+
+ if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
+ throw InvalidFeatureException();
+
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
+
+void SubNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
+{
+ _feats[0]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot);
+ _feats[1]->update_add_sub_leaves(add_sub_leaves, -1*pl_mn, expected_abs_tot);
+}
+
+void SubNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
+{
+ std::string key = expr();
+ if(div_mult_leaves.count(key) > 0)
+ div_mult_leaves[key] += fact;
+ else
+ div_mult_leaves[key] = fact;
+
+ expected_abs_tot *= std::abs(fact);
+}
diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp
index daa24450..c8b167d2 100644
--- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp
+++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/subtract.hpp
@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::SUB;}
+ /**
+ * @brief update the dictionary used to check if an Add/Sub node is valid
+ *
+ * @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
+ * @param pl_mn if for an addition node: 1 if for a subtraction node: -1
+ * @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
+ */
+ void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
+
+ /**
+ * @brief update the dictionary used to check if
+ * @details [long description]
+ *
+ * @param add_sub_leaves [description]
+ * @param fact amount to increment the dictionary by
+ * @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
+ */
+ void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
+
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
--
GitLab