Skip to content
Snippets Groups Projects
Select Git revision
  • becfcade5441def88b5461084cfe423896a656c5
  • master default protected
  • pre_gpu_changes
  • classification
4 results

test_mult_node.cc

Blame
  • test_mult_node.cc 5.33 KiB
    #ifdef PARAMETERIZE
    #include <feature_creation/node/operator_nodes/allowed_operator_nodes/mult/parameterized_multiply.hpp>
    #include <feature_creation/node/value_storage/nodes_value_containers.hpp>
    #include <feature_creation/node/FeatureNode.hpp>
    #include "gtest/gtest.h"
    
    #include <iomanip>
    #include <memory>
    #include <random>
    #include <sstream>
    #include <vector>
    
    namespace
    {
        /**
         * @brief Shared fixture for the MultParamNode tests.
         *
         * Builds two random primary features, "A" (unit m) and "B" (unit s), each with
         * 900 training and 10 test samples. It then draws a random parameter pair
         * (_a, _alpha) and produces the target property _prop by applying
         * allowed_op_funcs::mult to A and B with those parameters. A "regression"
         * optimizer is built against _prop, so a fitted MultParamNode over (A, B)
         * is expected to reproduce _prop.
         */
        class MultParamNodeTest : public ::testing::Test
        {
        protected:
            void SetUp() override
            {
                // NOTE(review): presumably limits parameter optimization to this node's own
                // parameters (no recursion into children) — confirm against nlopt_wrapper.
                nlopt_wrapper::MAX_PARAM_DEPTH = 1;

                // Global value storage. The arguments appear to be (train samples, test samples,
                // primary features), which matches the sizes used below — TODO confirm.
                node_value_arrs::initialize_values_arr(900, 10, 2);
                _task_sizes_train = {900};

                // Default-seeded engine, so every run sees the same data.
                std::default_random_engine generator;
                std::uniform_real_distribution<double> distribution_feats(-50.0, 50.0);
                std::uniform_real_distribution<double> distribution_params(-2.50, 2.50);

                // Fill two equally sized vectors, alternating draws between them sample by
                // sample. This keeps the random stream in a fixed, reproducible order.
                auto fill_interleaved = [&](std::vector<double>& first, std::vector<double>& second)
                {
                    for(std::size_t idx = 0; idx < first.size(); ++idx)
                    {
                        first[idx] = distribution_feats(generator);
                        second[idx] = distribution_feats(generator);
                    }
                };

                std::vector<double> value_1(900, 0.0);
                std::vector<double> value_2(900, 0.0);
                fill_interleaved(value_1, value_2);

                std::vector<double> test_value_1(10, 0.0);
                std::vector<double> test_value_2(10, 0.0);
                fill_interleaved(test_value_1, test_value_2);

                _feat_1 = std::make_shared<FeatureNode>(0, "A", value_1, test_value_1, Unit("m"));
                _feat_2 = std::make_shared<FeatureNode>(1, "B", value_2, test_value_2, Unit("s"));
                _phi = {_feat_1, _feat_2};

                // Draw order matters for reproducibility: _a first, then _alpha.
                _a = distribution_params(generator);
                _alpha = distribution_params(generator);

                // The target property is the mult operator evaluated with the hidden parameters.
                _prop.assign(900, 0.0);
                allowed_op_funcs::mult(900, _phi[0]->value_ptr(), _phi[1]->value_ptr(), _alpha, _a, _prop.data());

                _optimizer = nlopt_wrapper::get_optimizer("regression", _task_sizes_train, _prop, 1);
            }

            // Primary input features and the node under test.
            node_ptr _feat_1;
            node_ptr _feat_2;
            node_ptr _mult_test;

            // Feature list handed to the generator, target property, and training task sizes.
            std::vector<node_ptr> _phi;
            std::vector<double> _prop;
            std::vector<int> _task_sizes_train;

            // Hidden parameters used to build _prop.
            double _a;
            double _alpha;
            std::shared_ptr<NLOptimizer> _optimizer;
        };
    
        // generateMultParamNode must append a new feature to _phi only when the fitted
        // feature respects the [l_bound, u_bound] magnitude window. When it does, the
        // new feature should reproduce the target property _prop.
        TEST_F(MultParamNodeTest, GeneratorTest)
        {
            int feat_ind = _phi.size();

            // Window far below the feature magnitudes: nothing may be appended.
            generateMultParamNode(_phi, _phi[0], _phi[1], feat_ind, 1e-50, 1e-40, _optimizer);
            EXPECT_EQ(_phi.size(), 2u) << " (MultParamNode created with an absolute value above the upper bound)";

            // Window far above the feature magnitudes: nothing may be appended.
            generateMultParamNode(_phi, _phi[0], _phi[1], feat_ind, 1e49, 1e50, _optimizer);
            EXPECT_EQ(_phi.size(), 2u) << " (MultParamNode created with an absolute value below the lower bound)";

            // Wide-open window: exactly one new feature is expected.
            generateMultParamNode(_phi, _phi[0], _phi[1], feat_ind, 1e-50, 1e50, _optimizer);
            // Fatal on failure. Otherwise _phi.back() would still be the input feature B
            // and the r2 check below would test the wrong node.
            ASSERT_EQ(_phi.size(), 3u) << " (Failure to create a valid feature)";
            EXPECT_LT(1.0 - util_funcs::r2(_prop.data(), _phi.back()->value_ptr(), 900), 1e-4);
        }
    
        // The MultParamNode constructor must reject (throw InvalidFeatureException) a
        // feature whose magnitudes fall outside [l_bound, u_bound]. It must accept one
        // inside the window, and the accepted feature must reproduce _prop.
        TEST_F(MultParamNodeTest, ConstructorTest)
        {
            int feat_ind = _phi.size();

            // Window far below the feature magnitudes: construction must throw.
            EXPECT_THROW(
                _mult_test = std::make_shared<MultParamNode>(_phi[0], _phi[1], feat_ind, 1e-50, 1e-40, _optimizer),
                InvalidFeatureException
            ) << " (MultParamNode created with an absolute value above the upper bound)";

            // Window far above the feature magnitudes: construction must throw.
            EXPECT_THROW(
                _mult_test = std::make_shared<MultParamNode>(_phi[0], _phi[1], feat_ind, 1e49, 1e50, _optimizer),
                InvalidFeatureException
            ) << " (MultParamNode created with an absolute value below the lower bound)";

            // Wide-open window: construction must succeed. The assertion is fatal so that
            // _mult_test is never dereferenced after a failed construction.
            ASSERT_NO_THROW(
                _mult_test = std::make_shared<MultParamNode>(_phi[0], _phi[1], feat_ind, 1e-50, 1e50, _optimizer)
            ) << " (Failure to create a valid feature)";
            EXPECT_LT(1.0 - util_funcs::r2(_prop.data(), _mult_test->value_ptr(), 900), 1e-4);
        }
    
        // Checks the basic attributes of a successfully constructed MultParamNode:
        // - rung 1;
        // - training and test values that match allowed_op_funcs::mult evaluated with
        //   the node's own fitted parameters;
        // - product unit "m * s";
        // - postfix expression "0|1|mult: <p0>,<p1>" (13-digit scientific parameters).
        TEST_F(MultParamNodeTest, AttributesTest)
        {
            int feat_ind = _phi.size();
            _mult_test = std::make_shared<MultParamNode>(_phi[0], _phi[1], feat_ind, 1e-50, 1e50, _optimizer);

            EXPECT_EQ(_mult_test->rung(), 1);

            // Training values: compare every sample, not just the first one.
            std::vector<double> expected_val(900, 0.0);
            allowed_op_funcs::mult(900, _phi[0]->value_ptr(), _phi[1]->value_ptr(), _mult_test->parameters()[0], _mult_test->parameters()[1], expected_val.data());

            const double* val_ptr = _mult_test->value_ptr();
            for(int ii = 0; ii < 900; ++ii)
            {
                EXPECT_NEAR(val_ptr[ii], expected_val[ii], 1e-10) << " (value_ptr() mismatch at training sample " << ii << ")";
            }

            const auto val = _mult_test->value();
            for(int ii = 0; ii < 900; ++ii)
            {
                EXPECT_NEAR(val[ii], expected_val[ii], 1e-10) << " (value() mismatch at training sample " << ii << ")";
            }

            // Test values use their own reference buffer.
            std::vector<double> expected_test_val(10, 0.0);
            allowed_op_funcs::mult(10, _phi[0]->test_value_ptr(), _phi[1]->test_value_ptr(), _mult_test->parameters()[0], _mult_test->parameters()[1], expected_test_val.data());

            const double* test_val_ptr = _mult_test->test_value_ptr();
            for(int ii = 0; ii < 10; ++ii)
            {
                EXPECT_NEAR(test_val_ptr[ii], expected_test_val[ii], 1e-10) << " (test_value_ptr() mismatch at test sample " << ii << ")";
            }

            const auto test_val = _mult_test->test_value();
            for(int ii = 0; ii < 10; ++ii)
            {
                EXPECT_NEAR(test_val[ii], expected_test_val[ii], 1e-10) << " (test_value() mismatch at test sample " << ii << ")";
            }

            // Postfix encodes the child feature indices, the operator, and both fitted
            // parameters in 13-digit scientific notation.
            std::stringstream postfix;
            postfix << "0|1|mult: " << std::setprecision(13) << std::scientific << _mult_test->parameters()[0] << ',' << _mult_test->parameters()[1];
            EXPECT_EQ(_mult_test->unit().toString(), "m * s");
            EXPECT_EQ(_mult_test->postfix_expr(), postfix.str());
        }
    }
    #endif