Commit 9960b3a9 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Bug Fixes: Setup Python library to not initialize centeral storage with n_samples

Minor changes to fix bugs when setting these changes up
parent d6d4f35d
......@@ -437,7 +437,7 @@ namespace node_value_arrs
#ifdef PY_BINDINGS
// DocString: node_vals_ts_list
// DocString: node_vals_ts_list_no_params
/**
* @brief Initialize the node value arrays
* @details Using the size of the initial feature space constructor the storage arrays
......@@ -463,7 +463,7 @@ namespace node_value_arrs
);
}
// DocString: node_vals_ts_arr
// DocString: node_vals_ts_arr_no_params
/**
* @brief Initialize the node value arrays
* @details Using the size of the initial feature space constructor the storage arrays
......@@ -488,6 +488,62 @@ namespace node_value_arrs
false
);
}
// DocString: node_vals_ts_list
/**
* @brief Initialize the node value arrays
* @details Using the size of the initial feature space constructor the storage arrays
*
* @param task_sz_train (list): The number of training samples per task
* @param task_sz_test (list): The number of test sample per task
* @param n_primary_feat (int): The number of primary features
* @param max_rung (int): The maximum rung for all features
* @param use_params (bool): If true also initialize parameterized storage
*/
inline void initialize_values_arr(
py::list task_sz_train,
py::list task_sz_test,
int n_primary_feat,
int max_rung,
bool use_params
)
{
initialize_values_arr(
python_conv_utils::from_list<int>(task_sz_train),
python_conv_utils::from_list<int>(task_sz_test),
n_primary_feat,
max_rung,
use_params
);
}
// DocString: node_vals_ts_arr
/**
* @brief Initialize the node value arrays
* @details Using the size of the initial feature space constructor the storage arrays
*
* @param task_sz_train (np.ndarray): The number of training samples per task
* @param task_sz_test (np.ndarray): The number of test sample per task
* @param n_primary_feat (int): The number of primary features
* @param max_rung (int): The maximum rung for all features
* @param use_params (bool): If true also initialize parameterized storage
*/
inline void initialize_values_arr(
np::ndarray task_sz_train,
np::ndarray task_sz_test,
int n_primary_feat,
int max_rung,
bool use_params
)
{
initialize_values_arr(
python_conv_utils::from_ndarray<int>(task_sz_train),
python_conv_utils::from_ndarray<int>(task_sz_test),
n_primary_feat,
max_rung,
use_params
);
}
#endif
}
......
......@@ -977,7 +977,10 @@ public:
/**
* @brief Sets a list of FeatureNodes for the primary feature space
*/
inline void set_phi_0_py(py::list phi_0) {_phi_0 = python_conv_utils::from_list<FeatureNode>(phi_0);}
inline void set_phi_0_py(py::list phi_0)
{
set_phi_0(python_conv_utils::from_list<FeatureNode>(phi_0));
}
#endif
};
......
......@@ -60,8 +60,6 @@ void sisso::register_all()
sisso::feature_creation::node::registerSixPowNode();
void (*init_val_ar)(int, int, int, int) = &node_value_arrs::initialize_values_arr;
void (*init_val_ar_list)(py::list, py::list, int, int) = &node_value_arrs::initialize_values_arr;
void (*init_val_ar_arr)(np::ndarray, np::ndarray, int, int) = &node_value_arrs::initialize_values_arr;
def(
"phi_selected_from_file",
......@@ -75,18 +73,40 @@ void sisso::register_all()
(arg("n_samples"), arg("n_samples_test"), arg("n_primary_feat"), arg("max_rung")),
"@DocString_node_vals_init_no_ts@"
);
void (*init_val_ar_list)(py::list, py::list, int, int, bool) = &node_value_arrs::initialize_values_arr;
void (*init_val_ar_arr)(np::ndarray, np::ndarray, int, int, bool) = &node_value_arrs::initialize_values_arr;
void (*init_val_ar_list_no_params)(py::list, py::list, int, int) = &node_value_arrs::initialize_values_arr;
void (*init_val_ar_arr_no_params)(np::ndarray, np::ndarray, int, int) = &node_value_arrs::initialize_values_arr;
def(
"initialize_values_arr",
init_val_ar_list,
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung")),
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung"), arg("use_params")),
"@DocString_node_vals_ts_list@"
);
def(
"initialize_values_arr",
init_val_ar_arr,
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung")),
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung"), arg("use_params")),
"@DocString_node_vals_ts_arr@"
);
def(
"initialize_values_arr",
init_val_ar_list_no_params,
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung")),
"@DocString_node_vals_ts_list_no_params@"
);
def(
"initialize_values_arr",
init_val_ar_arr_no_params,
(arg("task_sz_train"), arg("task_sz_test"), arg("n_primary_feat"), arg("max_rung")),
"@DocString_node_vals_ts_arr_no_params@"
);
def(
"initialize_d_matrix_arr",
&node_value_arrs::initialize_d_matrix_arr,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment