Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
7bb88123
Commit
7bb88123
authored
Oct 25, 2016
by
theos
Browse files
Fixed LineEnergy. Optimized the use of Energy class in LineSearchStrongWolfe.
parent
17b1c6e9
Changes
6
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter_hamiltonian.py
View file @
7bb88123
...
@@ -78,17 +78,17 @@ if __name__ == "__main__":
...
@@ -78,17 +78,17 @@ if __name__ == "__main__":
D
=
PropagatorOperator
(
S
=
S
,
N
=
N
,
R
=
R
)
D
=
PropagatorOperator
(
S
=
S
,
N
=
N
,
R
=
R
)
def
distance_measure
(
energy
,
iteration
):
def
distance_measure
(
energy
,
iteration
):
pass
x
=
energy
.
position
#
print (iteration, ((x-ss).norm()/ss.norm()).real)
print
(
iteration
,
((
x
-
ss
).
norm
()
/
ss
.
norm
()).
real
)
minimizer
=
SteepestDescent
(
convergence_tolerance
=
0
,
minimizer
=
SteepestDescent
(
convergence_tolerance
=
0
,
iteration_limit
=
50
,
iteration_limit
=
50
,
callback
=
distance_measure
)
callback
=
distance_measure
)
#
minimizer = VL_BFGS(convergence_tolerance=0,
minimizer
=
VL_BFGS
(
convergence_tolerance
=
0
,
#
iteration_limit=50,
iteration_limit
=
50
,
#
callback=distance_measure,
callback
=
distance_measure
,
#
max_history_length=5)
max_history_length
=
5
)
m0
=
Field
(
s_space
,
val
=
1
)
m0
=
Field
(
s_space
,
val
=
1
)
...
...
nifty/energies/line_energy.py
View file @
7bb88123
...
@@ -4,19 +4,22 @@ from .energy import Energy
...
@@ -4,19 +4,22 @@ from .energy import Energy
class
LineEnergy
(
Energy
):
class
LineEnergy
(
Energy
):
def
__init__
(
self
,
position
,
energy
,
line_direction
):
def
__init__
(
self
,
position
,
energy
,
line_direction
,
zero_point
=
None
):
self
.
energy
=
energy
self
.
line_direction
=
line_direction
super
(
LineEnergy
,
self
).
__init__
(
position
=
position
)
super
(
LineEnergy
,
self
).
__init__
(
position
=
position
)
self
.
line_direction
=
line_direction
if
zero_point
is
None
:
zero_point
=
energy
.
position
self
.
_zero_point
=
zero_point
position_on_line
=
self
.
_zero_point
+
self
.
position
*
line_direction
self
.
energy
=
energy
.
at
(
position
=
position_on_line
)
def
at
(
self
,
position
):
def
at
(
self
,
position
):
if
position
==
0
:
return
self
.
__class__
(
position
,
return
self
self
.
energy
,
else
:
self
.
line_direction
,
full_position
=
self
.
position
+
self
.
line_direction
*
position
zero_point
=
self
.
_zero_point
)
return
self
.
__class__
(
full_position
,
self
.
energy
,
self
.
line_direction
)
@
property
@
property
def
value
(
self
):
def
value
(
self
):
...
...
nifty/minimization/line_searching/line_search.py
View file @
7bb88123
...
@@ -59,6 +59,5 @@ class LineSearch(object, Loggable):
...
@@ -59,6 +59,5 @@ class LineSearch(object, Loggable):
self
.
f_k_minus_1
=
f_k_minus_1
self
.
f_k_minus_1
=
f_k_minus_1
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
perform_line_search
(
self
,
xk
,
pk
,
f_k
=
None
,
fprime_k
=
None
,
def
perform_line_search
(
self
,
energy
,
pk
,
f_k_minus_1
=
None
):
f_k_minus_1
=
None
):
raise
NotImplementedError
raise
NotImplementedError
nifty/minimization/line_searching/line_search_strong_wolfe.py
View file @
7bb88123
...
@@ -54,8 +54,9 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -54,8 +54,9 @@ class LineSearchStrongWolfe(LineSearch):
# initialize the zero phis
# initialize the zero phis
old_phi_0
=
self
.
f_k_minus_1
old_phi_0
=
self
.
f_k_minus_1
phi_0
=
self
.
line_energy
.
at
(
0
).
value
energy_0
=
self
.
line_energy
.
at
(
0
)
phiprime_0
=
self
.
line_energy
.
at
(
0
).
gradient
phi_0
=
energy_0
.
value
phiprime_0
=
energy_0
.
gradient
if
phiprime_0
==
0
:
if
phiprime_0
==
0
:
self
.
logger
.
warn
(
"Flat gradient in search direction."
)
self
.
logger
.
warn
(
"Flat gradient in search direction."
)
...
@@ -82,11 +83,13 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -82,11 +83,13 @@ class LineSearchStrongWolfe(LineSearch):
self
.
logger
.
warn
(
"Increment size became 0."
)
self
.
logger
.
warn
(
"Increment size became 0."
)
alpha_star
=
0.
alpha_star
=
0.
phi_star
=
phi_0
phi_star
=
phi_0
energy_star
=
energy_0
break
break
if
(
phi_alpha1
>
phi_0
+
c1
*
alpha1
*
phiprime_0
)
or
\
if
(
phi_alpha1
>
phi_0
+
c1
*
alpha1
*
phiprime_0
)
or
\
((
phi_alpha1
>=
phi_alpha0
)
and
(
i
>
1
)):
((
phi_alpha1
>=
phi_alpha0
)
and
(
i
>
1
)):
(
alpha_star
,
phi_star
)
=
self
.
_zoom
(
alpha0
,
alpha1
,
(
alpha_star
,
phi_star
,
energy_star
)
=
self
.
_zoom
(
alpha0
,
alpha1
,
phi_0
,
phiprime_0
,
phi_0
,
phiprime_0
,
phi_alpha0
,
phi_alpha0
,
phiprime_alpha0
,
phiprime_alpha0
,
...
@@ -98,10 +101,12 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -98,10 +101,12 @@ class LineSearchStrongWolfe(LineSearch):
if
abs
(
phiprime_alpha1
)
<=
-
c2
*
phiprime_0
:
if
abs
(
phiprime_alpha1
)
<=
-
c2
*
phiprime_0
:
alpha_star
=
alpha1
alpha_star
=
alpha1
phi_star
=
phi_alpha1
phi_star
=
phi_alpha1
energy_star
=
energy_alpha1
break
break
if
phiprime_alpha1
>=
0
:
if
phiprime_alpha1
>=
0
:
(
alpha_star
,
phi_star
)
=
self
.
_zoom
(
alpha1
,
alpha0
,
(
alpha_star
,
phi_star
,
energy_star
)
=
self
.
_zoom
(
alpha1
,
alpha0
,
phi_0
,
phiprime_0
,
phi_0
,
phiprime_0
,
phi_alpha1
,
phi_alpha1
,
phiprime_alpha1
,
phiprime_alpha1
,
...
@@ -119,10 +124,15 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -119,10 +124,15 @@ class LineSearchStrongWolfe(LineSearch):
# max_iterations was reached
# max_iterations was reached
alpha_star
=
alpha1
alpha_star
=
alpha1
phi_star
=
phi_alpha1
phi_star
=
phi_alpha1
energy_star
=
energy_alpha1
self
.
logger
.
error
(
"The line search algorithm did not converge."
)
self
.
logger
.
error
(
"The line search algorithm did not converge."
)
self
.
_last_alpha_star
=
alpha_star
self
.
_last_alpha_star
=
alpha_star
return
alpha_star
,
phi_star
# extract the full energy from the line_energy
energy_star
=
energy_star
.
energy
return
alpha_star
,
phi_star
,
energy_star
def
_zoom
(
self
,
alpha_lo
,
alpha_hi
,
phi_0
,
phiprime_0
,
def
_zoom
(
self
,
alpha_lo
,
alpha_hi
,
phi_0
,
phiprime_0
,
phi_lo
,
phiprime_lo
,
phi_hi
,
c1
,
c2
):
phi_lo
,
phiprime_lo
,
phi_hi
,
c1
,
c2
):
...
@@ -176,6 +186,7 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -176,6 +186,7 @@ class LineSearchStrongWolfe(LineSearch):
if
abs
(
phiprime_alphaj
)
<=
-
c2
*
phiprime_0
:
if
abs
(
phiprime_alphaj
)
<=
-
c2
*
phiprime_0
:
alpha_star
=
alpha_j
alpha_star
=
alpha_j
phi_star
=
phi_alphaj
phi_star
=
phi_alphaj
energy_star
=
energy_alphaj
break
break
# If not, check the sign of the slope
# If not, check the sign of the slope
if
phiprime_alphaj
*
delta_alpha
>=
0
:
if
phiprime_alphaj
*
delta_alpha
>=
0
:
...
@@ -188,11 +199,12 @@ class LineSearchStrongWolfe(LineSearch):
...
@@ -188,11 +199,12 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj
)
phiprime_alphaj
)
else
:
else
:
alpha_star
,
phi_star
=
alpha_j
,
phi_alphaj
alpha_star
,
phi_star
,
energy_star
=
\
alpha_j
,
phi_alphaj
,
energy_alphaj
self
.
logger
.
error
(
"The line search algorithm (zoom) did not "
self
.
logger
.
error
(
"The line search algorithm (zoom) did not "
"converge."
)
"converge."
)
return
alpha_star
,
phi_star
return
alpha_star
,
phi_star
,
energy_star
def
_cubicmin
(
self
,
a
,
fa
,
fpa
,
b
,
fb
,
c
,
fc
):
def
_cubicmin
(
self
,
a
,
fa
,
fpa
,
b
,
fb
,
c
,
fc
):
"""
"""
...
...
nifty/minimization/quasi_newton_minimizer.py
View file @
7bb88123
...
@@ -86,13 +86,11 @@ class QuasiNewtonMinimizer(object, Loggable):
...
@@ -86,13 +86,11 @@ class QuasiNewtonMinimizer(object, Loggable):
# compute the step length, which minimizes energy.value along the
# compute the step length, which minimizes energy.value along the
# search direction
# search direction
step_length
,
step_length
=
self
.
line_searcher
.
perform_line_search
(
step_length
,
f_k
,
new_energy
=
\
self
.
line_searcher
.
perform_line_search
(
energy
=
energy
,
energy
=
energy
,
pk
=
descend_direction
,
pk
=
descend_direction
,
f_k_minus_1
=
f_k_minus_1
)
f_k_minus_1
=
f_k_minus_1
)
new_position
=
current_position
+
step_length
*
descend_direction
new_energy
=
energy
.
at
(
new_position
)
f_k_minus_1
=
energy
.
value
f_k_minus_1
=
energy
.
value
energy
=
new_energy
energy
=
new_energy
...
...
nifty/minimization/vl_bfgs.py
View file @
7bb88123
...
@@ -20,9 +20,9 @@ class VL_BFGS(QuasiNewtonMinimizer):
...
@@ -20,9 +20,9 @@ class VL_BFGS(QuasiNewtonMinimizer):
self
.
max_history_length
=
max_history_length
self
.
max_history_length
=
max_history_length
def
__call__
(
self
,
x0
,
f
,
fprime
,
f_args
=
()
):
def
__call__
(
self
,
energy
):
self
.
_information_store
=
None
self
.
_information_store
=
None
return
super
(
VL_BFGS
,
self
).
__call__
(
x0
,
f
,
fprime
,
f_args
=
()
)
return
super
(
VL_BFGS
,
self
).
__call__
(
energy
)
def
_get_descend_direction
(
self
,
x
,
gradient
):
def
_get_descend_direction
(
self
,
x
,
gradient
):
# initialize the information store if it doesn't already exist
# initialize the information store if it doesn't already exist
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment