Commit 6e71f1e8 authored by Martin Reinecke

more descriptive interface for LineEnergy

parent 12a5a3ac
Pipeline #15149 passed with stage
in 6 minutes and 36 seconds
...@@ -25,12 +25,12 @@ class LineEnergy: ...@@ -25,12 +25,12 @@ class LineEnergy:
Parameters Parameters
---------- ----------
linepos : float line_position : float
Defines the full spatial position of this energy via Defines the full spatial position of this energy via
self.energy.position = zero_point + linepos*line_direction self.energy.position = zero_point + line_position*line_direction
energy : Energy energy : Energy
The Energy object which will be evaluated along the given direction. The Energy object which will be evaluated along the given direction.
linedir : Field line_direction : Field
Direction used for line evaluation. Does not have to be normalized. Direction used for line evaluation. Does not have to be normalized.
offset : float *optional* offset : float *optional*
Indirectly defines the zero point of the line via the equation Indirectly defines the zero point of the line via the equation
...@@ -39,13 +39,13 @@ class LineEnergy: ...@@ -39,13 +39,13 @@ class LineEnergy:
Attributes Attributes
---------- ----------
linepos : float line_position : float
The position along the given line direction relative to the zero point. The position along the given line direction relative to the zero point.
value : float value : float
The value of the energy functional at the given position The value of the energy functional at the given position
dd : float directional_derivative : float
The directional derivative at the given position The directional derivative at the given position
linedir : Field line_direction : Field
Direction along which the movement is restricted. Does not have to be Direction along which the movement is restricted. Does not have to be
normalized. normalized.
energy : Energy energy : Energy
...@@ -67,19 +67,20 @@ class LineEnergy: ...@@ -67,19 +67,20 @@ class LineEnergy:
""" """
def __init__(self, linepos, energy, linedir, offset=0.): def __init__(self, line_position, energy, line_direction, offset=0.):
self._linepos = float(linepos) self._line_position = float(line_position)
self._linedir = linedir self._line_direction = line_direction
pos = energy.position + (self._linepos-float(offset))*self._linedir pos = energy.position \
+ (self._line_position-float(offset))*self._line_direction
self.energy = energy.at(position=pos) self.energy = energy.at(position=pos)
def at(self, linepos): def at(self, line_position):
""" Returns LineEnergy at new position, memorizing the zero point. """ Returns LineEnergy at new position, memorizing the zero point.
Parameters Parameters
---------- ----------
linepos : float line_position : float
Parameter for the new position on the line direction. Parameter for the new position on the line direction.
Returns Returns
...@@ -88,27 +89,27 @@ class LineEnergy: ...@@ -88,27 +89,27 @@ class LineEnergy:
""" """
return self.__class__(linepos, return self.__class__(line_position,
self.energy, self.energy,
self.linedir, self.line_direction,
offset=self.linepos) offset=self.line_position)
@property @property
def value(self): def value(self):
return self.energy.value return self.energy.value
@property @property
def linepos(self): def line_position(self):
return self._linepos return self._line_position
@property @property
def linedir(self): def line_direction(self):
return self._linedir return self._line_direction
@property @property
def dd(self): def directional_derivative(self):
res = self.energy.gradient.vdot(self.linedir) res = self.energy.gradient.vdot(self.line_direction)
if abs(res.imag)/max(abs(res.real),1.)>1e-12: if abs(res.imag) / max(abs(res.real), 1.) > 1e-12:
print "directional derivative has non-negligible " \ print "directional derivative has non-negligible " \
"imaginary part:", res "imaginary part:", res
return res.real return res.real
...@@ -162,7 +162,6 @@ class DescentMinimizer(Loggable, object): ...@@ -162,7 +162,6 @@ class DescentMinimizer(Loggable, object):
tx1=energy.position-new_energy.position tx1=energy.position-new_energy.position
# check if new energy value is bigger than old energy value # check if new energy value is bigger than old energy value
if (new_energy.value - energy.value) > 0: if (new_energy.value - energy.value) > 0:
print "Line search algorithm returned a new energy that was larger than the old one. Stopping."
self.logger.info("Line search algorithm returned a new energy " self.logger.info("Line search algorithm returned a new energy "
"that was larger than the old one. Stopping.") "that was larger than the old one. Stopping.")
break break
......
...@@ -63,9 +63,9 @@ class LineSearch(Loggable, object): ...@@ -63,9 +63,9 @@ class LineSearch(Loggable, object):
iteration of the line search procedure. (Default: None) iteration of the line search procedure. (Default: None)
""" """
self.line_energy = LineEnergy(linepos=0., self.line_energy = LineEnergy(line_position=0.,
energy=energy, energy=energy,
linedir=pk) line_direction=pk)
if f_k_minus_1 is not None: if f_k_minus_1 is not None:
f_k_minus_1 = f_k_minus_1.copy() f_k_minus_1 = f_k_minus_1.copy()
self.f_k_minus_1 = f_k_minus_1 self.f_k_minus_1 = f_k_minus_1
......
...@@ -110,7 +110,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -110,7 +110,7 @@ class LineSearchStrongWolfe(LineSearch):
old_phi_0 = self.f_k_minus_1 old_phi_0 = self.f_k_minus_1
le_0 = self.line_energy.at(0) le_0 = self.line_energy.at(0)
phi_0 = le_0.value phi_0 = le_0.value
phiprime_0 = le_0.dd phiprime_0 = le_0.directional_derivative
assert phiprime_0<0, "input direction must be a descent direction" assert phiprime_0<0, "input direction must be a descent direction"
# set alphas # set alphas
...@@ -141,7 +141,6 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -141,7 +141,6 @@ class LineSearchStrongWolfe(LineSearch):
if (phi_alpha1 > phi_0 + self.c1*alpha1*phiprime_0) or \ if (phi_alpha1 > phi_0 + self.c1*alpha1*phiprime_0) or \
((phi_alpha1 >= phi_alpha0) and (i > 0)): ((phi_alpha1 >= phi_alpha0) and (i > 0)):
print "zoom1:",i
(alpha_star, phi_star, le_star) = self._zoom( (alpha_star, phi_star, le_star) = self._zoom(
alpha0, alpha1, alpha0, alpha1,
phi_0, phiprime_0, phi_0, phiprime_0,
...@@ -150,7 +149,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -150,7 +149,7 @@ class LineSearchStrongWolfe(LineSearch):
phi_alpha1) phi_alpha1)
break break
phiprime_alpha1 = le_alpha1.dd phiprime_alpha1 = le_alpha1.directional_derivative
if abs(phiprime_alpha1) <= -self.c2*phiprime_0: if abs(phiprime_alpha1) <= -self.c2*phiprime_0:
alpha_star = alpha1 alpha_star = alpha1
phi_star = phi_alpha1 phi_star = phi_alpha1
...@@ -158,7 +157,6 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -158,7 +157,6 @@ class LineSearchStrongWolfe(LineSearch):
break break
if phiprime_alpha1 >= 0: if phiprime_alpha1 >= 0:
print "zoom2:",i
(alpha_star, phi_star, le_star) = self._zoom( (alpha_star, phi_star, le_star) = self._zoom(
alpha1, alpha0, alpha1, alpha0,
phi_0, phiprime_0, phi_0, phiprime_0,
...@@ -241,11 +239,9 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -241,11 +239,9 @@ class LineSearchStrongWolfe(LineSearch):
#phi_recent = None #phi_recent = None
assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0 assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0
assert phiprime_lo*(alpha_hi-alpha_lo)<0. assert phiprime_lo*(alpha_hi-alpha_lo)<0.
print "enter:"
for i in xrange(max_iterations): for i in xrange(max_iterations):
#assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0 #assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0
#assert phiprime_lo*(alpha_hi-alpha_lo)<0. #assert phiprime_lo*(alpha_hi-alpha_lo)<0.
# print alpha_lo, alpha_hi
delta_alpha = alpha_hi - alpha_lo delta_alpha = alpha_hi - alpha_lo
alpha_j = alpha_lo + 0.5*delta_alpha alpha_j = alpha_lo + 0.5*delta_alpha
...@@ -255,18 +251,11 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -255,18 +251,11 @@ class LineSearchStrongWolfe(LineSearch):
# If the first Wolfe condition is not met replace alpha_hi # If the first Wolfe condition is not met replace alpha_hi
# by alpha_j # by alpha_j
# print "W1:", phi_alphaj, phi_0 + self.c1*alpha_j*phiprime_0
print alpha_lo, phi_lo
print alpha_hi, phi_hi
# print phi_lo, phi_hi, phi_alphaj
# print phiprime_lo, phiprime_alphaj
if (phi_alphaj > phi_0 + self.c1*alpha_j*phiprime_0) or\ if (phi_alphaj > phi_0 + self.c1*alpha_j*phiprime_0) or\
(phi_alphaj >= phi_lo): (phi_alphaj >= phi_lo):
# print "beep"
alpha_hi, phi_hi = alpha_j, phi_alphaj alpha_hi, phi_hi = alpha_j, phi_alphaj
else: else:
# print "boop" phiprime_alphaj = le_alphaj.directional_derivative
phiprime_alphaj = le_alphaj.dd
# If the second Wolfe condition is met, return the result # If the second Wolfe condition is met, return the result
if abs(phiprime_alphaj) <= -self.c2*phiprime_0: if abs(phiprime_alphaj) <= -self.c2*phiprime_0:
alpha_star = alpha_j alpha_star = alpha_j
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment