Skip to content
Snippets Groups Projects
Commit 54384f7d authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Fix paramter initialization

For sign_alpha=-1 old initialization fails
parent c9315901
No related branches found
No related tags found
No related merge requests found
...@@ -267,13 +267,13 @@ void LogParamNode::initialize_params(double* params, const int depth) const ...@@ -267,13 +267,13 @@ void LogParamNode::initialize_params(double* params, const int depth) const
if(depth >= nlopt_wrapper::MAX_PARAM_DEPTH) if(depth >= nlopt_wrapper::MAX_PARAM_DEPTH)
{ {
val_ptr = _feats[0]->value_ptr(); val_ptr = _feats[0]->value_ptr();
params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10); params[1] = std::max(0.0, -1.0 * _sign_alpha * (*std::min_element(val_ptr, val_ptr + _n_samp, [this](double x1, double x2){return x1 * _sign_alpha < x2 * _sign_alpha;})) + 1e-10);
return; return;
} }
_feats[0]->initialize_params(params + 2, depth + 1); _feats[0]->initialize_params(params + 2, depth + 1);
val_ptr = _feats[0]->value_ptr(params + 2); val_ptr = _feats[0]->value_ptr(params + 2);
params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10); params[1] = std::max(0.0, -1.0 * _sign_alpha * (*std::min_element(val_ptr, val_ptr + _n_samp, [this](double x1, double x2){return x1 * _sign_alpha < x2 * _sign_alpha;})) + 1e-10);
} }
void LogParamNode::update_postfix(std::string& cur_expr, const bool add_params) const void LogParamNode::update_postfix(std::string& cur_expr, const bool add_params) const
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment