Skip to content
Snippets Groups Projects
Commit c9315901 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Use sign_alpha for logarithm

+/-x + \beta is not the same for loarithms
parent 4d36c238
Branches
No related tags found
No related merge requests found
......@@ -266,7 +266,7 @@ public:
* @param from_parent How many parameters are between the start of this node's parameters and its parent
* @param depth the current depth of the node on the Binary expression tree
*/
void set_bounds(double* lb, double* ub, const int depth=1) const;
virtual void set_bounds(double* lb, double* ub, const int depth=1) const;
/**
* @brief Set the bounds for the nl parameterization
......@@ -274,7 +274,7 @@ public:
* @param params pointer to the parameters vector
* @param depth the current depth of the node on the Binary expression tree
*/
void initialize_params(double* params, const int depth = 1) const;
virtual void initialize_params(double* params, const int depth = 1) const;
/**
* @brief Calculates the derivative of an operation with respect to the parameters for a given sample
......
......@@ -146,8 +146,19 @@ node_ptr LogParamNode::hard_copy()const
void LogParamNode::get_parameters(std::shared_ptr<NLOptimizer> optimizer)
{
// Change the sign of alpha as a control on linear dependency without forcing one sign or another
_sign_alpha = 1.0;
double min_res = optimizer->optimize_feature_params(this);
if(min_res == std::numeric_limits<double>::infinity())
std::vector<double> param_cp(_params);
_sign_alpha = -1.0;
double min_res_neg = optimizer->optimize_feature_params(this);
if(min_res_neg > min_res)
{
std::copy_n(param_cp.data(), param_cp.size(), _params.data());
_sign_alpha = 1.0;
}
else if(min_res_neg == std::numeric_limits<double>::infinity())
{
_params[0] = 0.0;
}
......@@ -234,6 +245,37 @@ void LogNode::initialize_params(double* params, const int depth) const
params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10);
}
void LogParamNode::set_bounds(double* lb, double* ub, const int depth) const
{
// The parameters of logarithm are dependent on the external shift/scale parameters, but physically relevant
lb[0] = _sign_alpha;
ub[0] = _sign_alpha;
if(depth >= nlopt_wrapper::MAX_PARAM_DEPTH)
{
return;
}
_feats[0]->set_bounds(lb + 2, ub + 2, depth + 1);
}
void LogParamNode::initialize_params(double* params, const int depth) const
{
params[0] = _sign_alpha;
double* val_ptr;
if(depth >= nlopt_wrapper::MAX_PARAM_DEPTH)
{
val_ptr = _feats[0]->value_ptr();
params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10);
return;
}
_feats[0]->initialize_params(params + 2, depth + 1);
val_ptr = _feats[0]->value_ptr(params + 2);
params[1] = std::max(0.0, -1.0 * (*std::min_element(val_ptr, val_ptr + _n_samp)) + 1e-10);
}
void LogParamNode::update_postfix(std::string& cur_expr, const bool add_params) const
{
std::stringstream postfix;
......
......@@ -43,6 +43,7 @@ protected:
using LogNode::matlab_fxn_expr;
std::vector<double> _params; //!< The parameters vector
double _sign_alpha; //!< Used to alternate between +/- 1
public:
/**
......@@ -161,6 +162,24 @@ public:
*/
inline std::string matlab_fxn_expr() const {return matlab_fxn_expr(_params.data());}
/**
* @brief Set the bounds for the nl parameterization
*
* @param lb pointer to the lower bounds data
* @param ub pointer to the upper bounds data
* @param from_parent How many parameters are between the start of this node's parameters and its parent
* @param depth the current depth of the node on the Binary expression tree
*/
void set_bounds(double* lb, double* ub, const int depth=1) const;
/**
* @brief Set the bounds for the nl parameterization
*
* @param params pointer to the parameters vector
* @param depth the current depth of the node on the Binary expression tree
*/
void initialize_params(double* params, const int depth = 1) const;
/**
* @brief The parameters used for introducing more non linearity in the operators
*/
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment