Commit 77ab05f9 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Basic Simiplification Added for Add, Subtract, Abs_Diff, Mult and Div Nodes

Checks for cancelations and duplicate features
parent 196ff8a1
......@@ -15,4 +15,24 @@ FeatureNode::FeatureNode(const FeatureNode &o) :
Node(o)
{}
void FeatureNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
if(add_sub_leaves.count(_expr) > 0)
add_sub_leaves[_expr] += pl_mn;
else
add_sub_leaves[_expr] = pl_mn;
++expected_abs_tot;
}
void FeatureNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
if(div_mult_leaves.count(_expr) > 0)
div_mult_leaves[_expr] += fact;
else
div_mult_leaves[_expr] = fact;
expected_abs_tot += std::abs(fact);
}
// BOOST_CLASS_EXPORT(FeatureNode)
......@@ -88,6 +88,23 @@ public:
*/
inline int rung(int cur_rung = 0){return cur_rung;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
/**
* @brief Serialization function to send over MPI
*
......
......@@ -97,6 +97,24 @@ public:
*/
virtual int rung(int cur_rung = 0) = 0;
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0;
/**
* @brief update the dictionary used to check if an Mult/Div node is valid
*
* @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0;
/**
* @brief Serialization function to send over MPI
*
......
......@@ -90,6 +90,24 @@ public:
* @brief Returns the type of node this is
*/
virtual NODE_TYPE type() = 0;
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
virtual void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot) = 0;
/**
* @brief update the dictionary used to check if an Mult/Div node is valid
*
* @param div_mult_leaves the dictionary used to check if an Mult/Div node is valid
* @param fact amount to increment the dictionary by
*/
virtual void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot) = 0;
};
#endif
......@@ -11,7 +11,19 @@ AbsDiffNode::AbsDiffNode(std::vector<node_ptr> feats, int rung, int feat_ind) :
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
if((feats[0]->type() == NODE_TYPE::LOG) && (feats[1]->type() == NODE_TYPE::LOG))
std::map<std::string, int> add_sub_leaves;
int expected_abs_tot = 0;
update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
if((add_sub_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
throw InvalidFeatureException();
int add_sub_tot_first = add_sub_leaves.begin()->second;
if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
throw InvalidFeatureException();
set_value();
......@@ -25,10 +37,43 @@ AbsDiffNode::AbsDiffNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_in
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
if((feat_1->type() == NODE_TYPE::LOG) && (feat_2->type() == NODE_TYPE::LOG))
std::map<std::string, int> add_sub_leaves;
int expected_abs_tot = 0;
update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
if((add_sub_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
throw InvalidFeatureException();
int add_sub_tot_first = add_sub_leaves.begin()->second;
if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void AbsDiffNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void AbsDiffNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
std::string key = expr();
if(div_mult_leaves.count(key) > 0)
div_mult_leaves[key] += fact;
else
div_mult_leaves[key] = fact;
expected_abs_tot *= std::abs(fact);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS_DIFF;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -20,3 +20,25 @@ AbsNode::AbsNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void AbsNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void AbsNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
std::string key = expr();
if(div_mult_leaves.count(key) > 0)
div_mult_leaves[key] += fact;
else
div_mult_leaves[key] = fact;
expected_abs_tot *= std::abs(fact);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -9,6 +9,21 @@ AddNode::AddNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
std::map<std::string, int> add_sub_leaves;
int expected_abs_tot = 0;
update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
if((add_sub_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
throw InvalidFeatureException();
int add_sub_tot_first = add_sub_leaves.begin()->second;
if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
......@@ -20,7 +35,39 @@ AddNode::AddNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
std::map<std::string, int> add_sub_leaves;
int expected_abs_tot = 0;
update_add_sub_leaves(add_sub_leaves, 1, expected_abs_tot);
if((add_sub_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(add_sub_leaves.begin(), add_sub_leaves.end(), -1*expected_abs_tot, [](int tot, auto el){return tot + std::abs(el.second);})) > 0)
throw InvalidFeatureException();
int add_sub_tot_first = add_sub_leaves.begin()->second;
if((std::abs(add_sub_tot_first) > 1) && std::all_of(add_sub_leaves.begin(), add_sub_leaves.end(), [&add_sub_tot_first](auto el){return el.second == add_sub_tot_first;}))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void AddNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
_feats[0]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot);
_feats[1]->update_add_sub_leaves(add_sub_leaves, pl_mn, expected_abs_tot);
}
void AddNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
std::string key = expr();
if(div_mult_leaves.count(key) > 0)
div_mult_leaves[key] += fact;
else
div_mult_leaves[key] = fact;
expected_abs_tot *= std::abs(fact);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::ADD;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -30,3 +30,26 @@ CosNode::CosNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void CosNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void CosNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
std::string key = expr();
if(div_mult_leaves.count(key) > 0)
div_mult_leaves[key] += fact;
else
div_mult_leaves[key] = fact;
expected_abs_tot *= std::abs(fact);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::COS;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -24,3 +24,19 @@ CbNode::CbNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void CbNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void CbNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
_feats[0]->update_div_mult_leaves(div_mult_leaves, fact * 3.0, expected_abs_tot);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::CB;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -24,3 +24,19 @@ CbrtNode::CbrtNode(node_ptr feat, int rung, int feat_ind):
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void CbrtNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void CbrtNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
_feats[0]->update_div_mult_leaves(div_mult_leaves, fact / 3.0, expected_abs_tot);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::CBRT;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leavesis valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*
* @param add_sub_leaves [description]
* @param fact amount to increment the dictionary by
* @param expected_abs_tot The expected absolute sum of all values in div_mult_leaves
*/
void update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot);
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -9,6 +9,21 @@ DivNode::DivNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if((feats[0]->type() == NODE_TYPE::INV) || (feats[1]->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
std::map<std::string, double> div_mult_leaves;
double expected_abs_tot = 0.0;
update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
if((div_mult_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
throw InvalidFeatureException();
int div_mult_tot_first = div_mult_leaves.begin()->second;
if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
......@@ -20,7 +35,39 @@ DivNode::DivNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
if((feat_1->type() == NODE_TYPE::INV) || (feat_2->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
std::map<std::string, double> div_mult_leaves;
double expected_abs_tot = 0.0;
update_div_mult_leaves(div_mult_leaves, 1.0, expected_abs_tot);
if((div_mult_leaves.size() < 2))
throw InvalidFeatureException();
if(std::abs(std::accumulate(div_mult_leaves.begin(), div_mult_leaves.end(), -1.0*expected_abs_tot, [](double tot, auto el){return tot + std::abs(el.second);})) > 1e-12)
throw InvalidFeatureException();
int div_mult_tot_first = div_mult_leaves.begin()->second;
if((std::abs(div_mult_tot_first) > 1) && std::all_of(div_mult_leaves.begin(), div_mult_leaves.end(), [&div_mult_tot_first](auto el){return el.second == div_mult_tot_first;}))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
throw InvalidFeatureException();
}
void DivNode::update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot)
{
std::string key = expr();
if(add_sub_leaves.count(key) > 0)
add_sub_leaves[key] += pl_mn;
else
add_sub_leaves[key] = pl_mn;
++expected_abs_tot;
}
void DivNode::update_div_mult_leaves(std::map<std::string, double>& div_mult_leaves, double fact, double& expected_abs_tot)
{
_feats[0]->update_div_mult_leaves(div_mult_leaves, fact, expected_abs_tot);
_feats[1]->update_div_mult_leaves(div_mult_leaves, -1.0*fact, expected_abs_tot);
}
......@@ -33,6 +33,25 @@ public:
*/
inline NODE_TYPE type(){return NODE_TYPE::DIV;}
/**
* @brief update the dictionary used to check if an Add/Sub node is valid
*
* @param add_sub_leaves the dictionary used to check if an Add/Sub node is valid
* @param pl_mn if for an addition node: 1 if for a subtraction node: -1
* @param expected_abs_tot The expected absolute sum of all values in add_sub_leaves
*/
void update_add_sub_leaves(std::map<std::string, int>& add_sub_leaves, int pl_mn, int& expected_abs_tot);
/**
* @brief update the dictionary used to check if
* @details [long description]
*