diff --git a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/parameterized_log.cpp b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/parameterized_log.cpp index a0febd370886fd420f1e4c9d9a335ed42f3c0969..b3deff63490e3f00f49f596d1a82aa700e5f545b 100644 --- a/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/parameterized_log.cpp +++ b/src/feature_creation/node/operator_nodes/allowed_operator_nodes/log/parameterized_log.cpp @@ -267,13 +267,13 @@ void LogParamNode::initialize_params(double* params, const int depth) const if(depth >= nlopt_wrapper::MAX_PARAM_DEPTH) { val_ptr = _feats[0]->value_ptr(); - params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10); + params[1] = std::max(0.0, -1.0 * _sign_alpha * (*std::min_element(val_ptr, val_ptr + _n_samp, [this](double x1, double x2){return x1 * _sign_alpha < x2 * _sign_alpha;})) + 1e-10); return; } _feats[0]->initialize_params(params + 2, depth + 1); val_ptr = _feats[0]->value_ptr(params + 2); - params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10); + params[1] = std::max(0.0, -1.0 * _sign_alpha * (*std::min_element(val_ptr, val_ptr + _n_samp, [this](double x1, double x2){return x1 * _sign_alpha < x2 * _sign_alpha;})) + 1e-10); } void LogParamNode::update_postfix(std::string& cur_expr, const bool add_params) const