Commit 8f0ade75 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Minor refactors

Added type condition tests (initial simplification)
Added som more documenatation
parent ec51b853
......@@ -126,30 +126,21 @@ void FeatureSpace::generate_feature_space()
}
}
_mpi_comm->barrier();
if(_mpi_comm->rank() == 0)
std::cout << "NEXT_PHI MADE" << std::endl;
else
std::cout << "NNNN NEXT_PHI MADE" << std::endl;
_start_gen.push_back(_phi.size());
std::vector<std::vector<node_ptr>> next_phi_gathered;
mpi::all_gather(*_mpi_comm, next_phi, next_phi_gathered);
if(_mpi_comm->rank() == 0)
std::cout << "all gather" << std::endl;
else
std::cout << "aaaa all_gather" << std::endl;
std::cout << nn << " set values" << std::endl;
for(auto& next_phi_vec : next_phi_gathered)
{
_phi.reserve(_phi.size() + next_phi_vec.size());
for(auto& feat : next_phi_vec)
{
if(nn <= node_value_arrs::N_RUNGS_STORED)
{
std::transform(_phi.begin(), _phi.end(), scores.begin(), [&feat](node_ptr f){return 1.0 - std::abs(util_funcs::r(feat->value_ptr(), f->value_ptr(), f->n_samp()));});
if(*std::min_element(scores.begin(), scores.begin()+_phi.size()) > 1e-13)
_phi.push_back(feat);
feat->set_value();
_phi.push_back(feat);
}
else
{
......@@ -157,7 +148,6 @@ void FeatureSpace::generate_feature_space()
}
}
}
std::cout << "DONE"<< std::endl;
}
_n_feat = _phi.size();
}
......
......@@ -91,6 +91,11 @@ public:
return std::all_of(value_ptr(), value_ptr() + _n_samp, [&mean](double d){return std::abs(d - mean) < 1e-12;});
}
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::FEAT;}
/**
* @brief Accessor function to the value of the feature
*/
......
......@@ -105,6 +105,11 @@ public:
*/
virtual bool is_const() = 0;
/**
* @brief Returns the type of node this is
*/
virtual NODE_TYPE type() = 0;
/**
* @brief Serialization function to send over MPI
*
......
......@@ -97,6 +97,11 @@ public:
return std::all_of(value_ptr(), value_ptr() + _n_samp, [&mean](double d){return std::abs(d - mean) < 1e-12;});
}
/**
* @brief Returns the type of node this is
*/
virtual NODE_TYPE type() = 0;
/**
* @brief Set up the feature value pointers
*/
......
......@@ -11,8 +11,11 @@ AbsDiffNode::AbsDiffNode(std::vector<node_ptr> feats, int rung, int feat_ind) :
if(feats[0]->unit() != feats[1]->unit())
throw InvalidFeatureException();
if((feats[0]->type() == NODE_TYPE::LOG) && (feats[1]->type() == NODE_TYPE::LOG))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......@@ -22,8 +25,11 @@ AbsDiffNode::AbsDiffNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_in
if(feat_1->unit() != feat_2->unit())
throw InvalidFeatureException();
if((feat_1->type() == NODE_TYPE::LOG) && (feat_2->type() == NODE_TYPE::LOG))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS_DIFF;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -9,7 +9,7 @@ AbsNode::AbsNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......@@ -17,7 +17,7 @@ AbsNode::AbsNode(node_ptr feat, int rung, int feat_ind):
OperatorNode({feat}, rung, feat_ind)
{
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::ABS;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -10,7 +10,7 @@ AddNode::AddNode(std::vector<node_ptr> feats, int rung, int feat_ind):
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......@@ -21,7 +21,7 @@ AddNode::AddNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::ADD;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -9,8 +9,11 @@ CosNode::CosNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != Unit())
throw InvalidFeatureException();
if((feats[0]->type() == NODE_TYPE::SIN) || (feats[0]->type() == NODE_TYPE::COS))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......@@ -20,8 +23,11 @@ CosNode::CosNode(node_ptr feat, int rung, int feat_ind):
if(feat->unit() != Unit())
throw InvalidFeatureException();
if((feat->type() == NODE_TYPE::SIN) || (feat->type() == NODE_TYPE::COS))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::COS;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -6,16 +6,22 @@ CbNode::CbNode()
CbNode::CbNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
if(feats[0]->type() == NODE_TYPE::CBRT)
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
CbNode::CbNode(node_ptr feat, int rung, int feat_ind):
OperatorNode({feat}, rung, feat_ind)
{
if(feat->type() == NODE_TYPE::CBRT)
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::CB;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -6,16 +6,22 @@ CbrtNode::CbrtNode()
CbrtNode::CbrtNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
if(feats[0]->type() == NODE_TYPE::CB)
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
CbrtNode::CbrtNode(node_ptr feat, int rung, int feat_ind):
OperatorNode({feat}, rung, feat_ind)
{
if(feat->type() == NODE_TYPE::CB)
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::CBRT;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -6,16 +6,22 @@ DivNode::DivNode()
DivNode::DivNode(std::vector<node_ptr> feats, int rung, int feat_ind):
OperatorNode(feats, rung, feat_ind)
{
if((feats[0]->type() == NODE_TYPE::INV) || (feats[1]->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
DivNode::DivNode(node_ptr feat_1, node_ptr feat_2, int rung, int feat_ind):
OperatorNode({feat_1, feat_2}, rung, feat_ind)
{
if((feat_1->type() == NODE_TYPE::INV) || (feat_2->type() == NODE_TYPE::INV))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::DIV;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
......@@ -9,8 +9,11 @@ ExpNode::ExpNode(std::vector<node_ptr> feats, int rung, int feat_ind):
if(feats[0]->unit() != Unit())
throw InvalidFeatureException();
if((feats[0]->type() == NODE_TYPE::NEG_EXP) || (feats[0]->type() == NODE_TYPE::EXP) || (feats[0]->type() == NODE_TYPE::ADD) || (feats[0]->type() == NODE_TYPE::SUB) || (feats[0]->type() == NODE_TYPE::LOG))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......@@ -20,8 +23,11 @@ ExpNode::ExpNode(node_ptr feat, int rung, int feat_ind):
if(feat->unit() != Unit())
throw InvalidFeatureException();
if((feat->type() == NODE_TYPE::NEG_EXP) || (feat->type() == NODE_TYPE::EXP) || (feat->type() == NODE_TYPE::ADD) || (feat->type() == NODE_TYPE::SUB) || (feat->type() == NODE_TYPE::LOG))
throw InvalidFeatureException();
set_value();
if(is_nan() || is_const())
if(is_nan() || is_const())
throw InvalidFeatureException();
}
......
......@@ -19,6 +19,11 @@ public:
void set_value();
/**
* @brief Returns the type of node this is
*/
inline NODE_TYPE type(){return NODE_TYPE::EXP;}
template <typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment