Commit c296d752 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Add tests for standardized values

This is now explicitly tested
parent 91dd07d0
......@@ -278,11 +278,18 @@ void node_value_arrs::finalize_values_arr()
TASK_START_TRAIN.resize(0);
TASK_SZ_TEST.resize(0);
PARAM_STORAGE_ARR.resize(0);
PARAM_STORAGE_TEST_ARR.resize(0);
D_MATRIX.resize(0);
VALUES_ARR.resize(0);
TEST_VALUES_ARR.resize(0);
TEMP_STORAGE_ARR.resize(0);
TEMP_STORAGE_TEST_ARR.resize(0);
PARAM_STORAGE_ARR.resize(0);
PARAM_STORAGE_TEST_ARR.resize(0);
STANDARDIZED_D_MATRIX.resize(0);
STANDARDIZED_STORAGE_ARR.resize(0);
STANDARDIZED_TEST_STORAGE_ARR.resize(0);
}
......@@ -21,9 +21,9 @@ namespace {
//test mean calculations
TEST(ValueStorage, ValueStorageTest)
{
EXPECT_THROW(node_value_arrs::initialize_values_arr({5}, {2}, 1, -2, true), std::logic_error);
EXPECT_THROW(node_value_arrs::initialize_values_arr({5}, {2}, 2, -2, true), std::logic_error);
node_value_arrs::initialize_values_arr({5}, {2}, 1, 2, true);
node_value_arrs::initialize_values_arr({5}, {2}, 2, 2, true);
EXPECT_THROW(node_value_arrs::set_task_sz_train({20}), std::logic_error);
EXPECT_THROW(node_value_arrs::set_task_sz_test({6}), std::logic_error);
......@@ -31,15 +31,19 @@ namespace {
EXPECT_EQ(node_value_arrs::N_SAMPLES, 5);
EXPECT_EQ(node_value_arrs::N_SAMPLES_TEST, 2);
EXPECT_EQ(node_value_arrs::N_RUNGS_STORED, 0);
EXPECT_EQ(node_value_arrs::N_STORE_FEATURES, 1);
EXPECT_EQ(node_value_arrs::N_STORE_FEATURES, 2);
EXPECT_EQ(node_value_arrs::N_OP_SLOTS, 6);
EXPECT_EQ(node_value_arrs::MAX_RUNG, 2);
EXPECT_EQ(node_value_arrs::VALUES_ARR.size(), 5);
EXPECT_EQ(node_value_arrs::TEST_VALUES_ARR.size(), 2);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_ARR.size(), node_value_arrs::MAX_N_THREADS * (6 * 1 + 1) * 5);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG.size(), node_value_arrs::MAX_N_THREADS * (6 * 1 + 1));
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_TEST_ARR.size(), node_value_arrs::MAX_N_THREADS * (6 * 1 + 1) * 2);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_TEST_REG.size(), node_value_arrs::MAX_N_THREADS * (6 * 1 + 1));
EXPECT_EQ(node_value_arrs::VALUES_ARR.size(), 10);
EXPECT_EQ(node_value_arrs::TEST_VALUES_ARR.size(), 4);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_ARR.size(), node_value_arrs::MAX_N_THREADS * (6 * 2 + 1) * 5);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG.size(), node_value_arrs::MAX_N_THREADS * (6 * 2 + 1));
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_TEST_ARR.size(), node_value_arrs::MAX_N_THREADS * (6 * 2 + 1) * 2);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_TEST_REG.size(), node_value_arrs::MAX_N_THREADS * (6 * 2 + 1));
EXPECT_EQ(node_value_arrs::STANDARDIZED_D_MATRIX.size(), 0);
EXPECT_EQ(node_value_arrs::STANDARDIZED_STORAGE_ARR.size(), node_value_arrs::MAX_N_THREADS * 2 * 3 * 5);
EXPECT_EQ(node_value_arrs::STANDARDIZED_TEST_STORAGE_ARR.size(), node_value_arrs::MAX_N_THREADS * 2 * 3 * 2);
EXPECT_THROW(node_value_arrs::resize_values_arr(10, 2), std::logic_error);
node_value_arrs::resize_values_arr(1, 2);
......@@ -53,14 +57,17 @@ namespace {
node_value_arrs::initialize_d_matrix_arr();
EXPECT_EQ(node_value_arrs::N_SELECTED, 0);
EXPECT_EQ(node_value_arrs::D_MATRIX.size(), 0);
EXPECT_EQ(node_value_arrs::STANDARDIZED_D_MATRIX.size(), 0);
node_value_arrs::resize_d_matrix_arr(2);
EXPECT_EQ(node_value_arrs::N_SELECTED, 2);
EXPECT_EQ(node_value_arrs::D_MATRIX.size(), 10);
EXPECT_EQ(node_value_arrs::STANDARDIZED_D_MATRIX.size(), 10);
node_value_arrs::resize_d_matrix_arr(3);
EXPECT_EQ(node_value_arrs::N_SELECTED, 5);
EXPECT_EQ(node_value_arrs::D_MATRIX.size(), 25);
EXPECT_EQ(node_value_arrs::STANDARDIZED_D_MATRIX.size(), 25);
node_value_arrs::get_value_ptr(1, 1, 0)[1] = 1.0;
EXPECT_EQ(node_value_arrs::VALUES_ARR[6], 1.0);
......@@ -71,22 +78,28 @@ namespace {
node_value_arrs::get_value_ptr(10, 141, 2, 0)[0] = 1.0;
EXPECT_EQ(node_value_arrs::temp_storage_reg(10, 2, 0, false), 141);
EXPECT_EQ(node_value_arrs::access_temp_storage(node_value_arrs::get_op_slot(2, 0, false))[0], 1.0);
EXPECT_EQ(node_value_arrs::access_temp_storage(node_value_arrs::get_op_slot(2, 0, false) * 2)[0], 1.0);
node_value_arrs::get_test_value_ptr(10, 141, 2, 0)[0] = 1.0;
EXPECT_EQ(node_value_arrs::temp_storage_test_reg(10, 2, 0, false), 141);
EXPECT_EQ(node_value_arrs::access_temp_storage_test(node_value_arrs::get_op_slot(2, 0, false))[0], 1.0);
EXPECT_EQ(node_value_arrs::access_temp_storage_test(node_value_arrs::get_op_slot(2, 0, false) * 2)[0], 1.0);
node_value_arrs::get_d_matrix_ptr(1)[0] = 1.0;
EXPECT_EQ(node_value_arrs::D_MATRIX[5], 1.0);
node_value_arrs::access_temp_stand_storage(1, false)[0] = 3.0;
EXPECT_EQ(node_value_arrs::STANDARDIZED_STORAGE_ARR[5 + omp_get_thread_num() * 30], 3.0);
node_value_arrs::access_temp_stand_storage_test(0, true)[0] = 3.0;
EXPECT_EQ(node_value_arrs::STANDARDIZED_TEST_STORAGE_ARR[4 + omp_get_thread_num() * 12], 3.0);
#pragma omp parallel
{
int sz_reg = (node_value_arrs::N_OP_SLOTS * node_value_arrs::N_PRIMARY_FEATURES + 1);
std::fill_n(node_value_arrs::TEMP_STORAGE_REG.data() + sz_reg * omp_get_thread_num(), sz_reg, omp_get_thread_num() + 1);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG[7 * omp_get_thread_num()], omp_get_thread_num() + 1);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG[14 * omp_get_thread_num()], omp_get_thread_num() + 1);
node_value_arrs::clear_temp_reg_thread();
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG[7 * omp_get_thread_num()], -1);
EXPECT_EQ(node_value_arrs::TEMP_STORAGE_REG[14 * omp_get_thread_num()], -1);
}
std::fill_n(node_value_arrs::TEMP_STORAGE_REG.data(), node_value_arrs::TEMP_STORAGE_REG.size(), 2.0);
......@@ -107,6 +120,9 @@ namespace {
EXPECT_EQ(node_value_arrs::PARAM_STORAGE_ARR.size(), 0);
EXPECT_EQ(node_value_arrs::PARAM_STORAGE_TEST_ARR.size(), 0);
EXPECT_EQ(node_value_arrs::D_MATRIX.size(), 0);
EXPECT_EQ(node_value_arrs::STANDARDIZED_D_MATRIX.size(), 0);
EXPECT_EQ(node_value_arrs::STANDARDIZED_STORAGE_ARR.size(), 0);
EXPECT_EQ(node_value_arrs::STANDARDIZED_TEST_STORAGE_ARR.size(), 0);
EXPECT_EQ(node_value_arrs::TASK_SZ_TRAIN.size(), 0);
EXPECT_EQ(node_value_arrs::TASK_START_TRAIN.size(), 0);
EXPECT_EQ(node_value_arrs::TASK_SZ_TEST.size(), 0);
......
......@@ -324,6 +324,21 @@ namespace {
EXPECT_EQ(util_funcs::max_abs_val<double>(dNeg3.data(), dNeg3.size()), 9.0);
}
// test standardize
TEST(MathUtils, StandardizeTest)
{
std::vector<double> test = {2, 4, 4, 4, 5, 5, 7, 9};
std::vector<double> test_std = {-1.5, -0.5, -0.5, -0.5, 0, 0, 1, 2};
util_funcs::standardize(test.data(), test.size(), test.data());
EXPECT_LT(std::abs(util_funcs::mean(test)), 1e-10);
EXPECT_LT(std::abs(util_funcs::stand_dev(test) - 1.0), 1e-10);
std::transform(test.begin(), test.end(), test_std.begin(), test.begin(), std::minus<double>());
EXPECT_TRUE(std::all_of(test.begin(), test.end(), [](double val){return std::abs(val) < 1e-10;}));
}
// test iterate
TEST(MathUtils, IterateTest)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment