From 5ee3f23d8b576632a53681cb031a3e19823e456d Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Wed, 19 Jul 2017 14:16:22 +0200 Subject: [PATCH] re-introduce higher-order interpolation --- .../line_search_strong_wolfe.py | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/nifty/minimization/line_searching/line_search_strong_wolfe.py b/nifty/minimization/line_searching/line_search_strong_wolfe.py index eb993f35..7be46724 100644 --- a/nifty/minimization/line_searching/line_search_strong_wolfe.py +++ b/nifty/minimization/line_searching/line_search_strong_wolfe.py @@ -234,16 +234,34 @@ class LineSearchStrongWolfe(LineSearch): cubic_delta = 0.2 # cubic quad_delta = 0.1 # quadratic phiprime_alphaj = 0. - # initialize the most recent versions (j-1) of phi and alpha - #alpha_recent = None - #phi_recent = None + assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0 assert phiprime_lo*(alpha_hi-alpha_lo)<0. for i in xrange(max_iterations): #assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0 #assert phiprime_lo*(alpha_hi-alpha_lo)<0. delta_alpha = alpha_hi - alpha_lo - alpha_j = alpha_lo + 0.5*delta_alpha + if delta_alpha < 0: + a, b = alpha_hi, alpha_lo + else: + a, b = alpha_lo, alpha_hi + + # Try cubic interpolation + if i > 0: + cubic_check = cubic_delta * delta_alpha + alpha_j = self._cubicmin(alpha_lo, phi_lo, phiprime_lo, + alpha_hi, phi_hi, + alpha_recent, phi_recent) + # If cubic was not successful or not available, try quadratic + if (i == 0) or (alpha_j is None) or (alpha_j > b - cubic_check) or\ + (alpha_j < a + cubic_check): + quad_check = quad_delta * delta_alpha + alpha_j = self._quadmin(alpha_lo, phi_lo, phiprime_lo, + alpha_hi, phi_hi) + # If quadratic was not successful, try bisection + if (alpha_j is None) or (alpha_j > b - quad_check) or \ + (alpha_j < a + quad_check): + alpha_j = alpha_lo + 0.5*delta_alpha # Check if the current value of alpha_j is already sufficient le_alphaj = self.line_energy.at(alpha_j) @@ -253,6 +271,7 @@ class LineSearchStrongWolfe(LineSearch): # by alpha_j if (phi_alphaj > phi_0 + self.c1*alpha_j*phiprime_0) or\ (phi_alphaj >= phi_lo): + alpha_recent, phi_recent = alpha_hi, phi_hi alpha_hi, phi_hi = alpha_j, phi_alphaj else: phiprime_alphaj = le_alphaj.directional_derivative @@ -264,7 +283,10 @@ class LineSearchStrongWolfe(LineSearch): break # If not, check the sign of the slope if phiprime_alphaj*delta_alpha >= 0: + alpha_recent, phi_recent = alpha_hi, phi_hi alpha_hi, phi_hi = alpha_lo, phi_lo + else: + alpha_recent, phi_recent = alpha_lo, phi_lo # Replace alpha_lo by alpha_j (alpha_lo, phi_lo, phiprime_lo) = (alpha_j, phi_alphaj, phiprime_alphaj) -- GitLab