Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2b4b6c07
Commit
2b4b6c07
authored
Aug 15, 2017
by
Martin Reinecke
Browse files
merge master
parents
8696dd81
b5d09b6a
Pipeline
#16578
passed with stage
in 35 minutes and 47 seconds
Changes
23
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
2b4b6c07
...
...
@@ -23,9 +23,9 @@ def plot_parameters(m, t, p, p_d):
t
=
t
.
val
.
get_full_data
().
real
p
=
p
.
val
.
get_full_data
().
real
p_d
=
p_d
.
val
.
get_full_data
().
real
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m
)],
filename
=
'map.html'
,
auto_open
=
False
)
pl
.
plot
([
go
.
Scatter
(
x
=
x
,
y
=
t
),
go
.
Scatter
(
x
=
x
,
y
=
p
),
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
)
go
.
Scatter
(
x
=
x
,
y
=
p_d
)],
filename
=
"t.html"
,
auto_open
=
False
)
class
AdjointFFTResponse
(
LinearOperator
):
...
...
@@ -106,7 +106,7 @@ if __name__ == "__main__":
data_power
=
log
(
fft
(
d
).
power_analyze
(
binbounds
=
p_space
.
binbounds
))
d_data
=
d
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
,
auto_open
=
False
)
# Minimization strategy
def
convergence_measure
(
a_energy
,
iteration
):
# returns current energy
...
...
demos/wiener_filter_via_curvature.py
View file @
2b4b6c07
...
...
@@ -2,22 +2,23 @@ import numpy as np
from
nifty
import
RGSpace
,
PowerSpace
,
Field
,
FFTOperator
,
ComposedOperator
,
\
DiagonalOperator
,
ResponseOperator
,
plotting
,
\
create_power_operator
create_power_operator
,
nifty_configuration
from
nifty.library
import
WienerFilterCurvature
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
nifty_configuration
[
'default_distribution_strategy'
]
=
'fftw'
nifty_configuration
[
'harmonic_rg_base'
]
=
'real'
# Setting up variable parameters
# Typical distance over which the field is correlated
correlation_length
=
0.0
1
correlation_length
=
0.0
5
# Variance of field in position space sqrt(<|s_x|^2>)
field_variance
=
2.
# smoothing length of response (in same unit as L)
response_sigma
=
0.1
response_sigma
=
0.
0
1
# The signal to noise ratio
signal_to_noise
=
0.7
...
...
@@ -36,19 +37,17 @@ if __name__ == "__main__":
signal_space
=
RGSpace
([
N_pixels
,
N_pixels
],
distances
=
L
/
N_pixels
)
harmonic_space
=
FFTOperator
.
get_default_codomain
(
signal_space
)
fft
=
FFTOperator
(
harmonic_space
,
target
=
signal_space
,
domain_dtype
=
np
.
complex
,
target_dtype
=
np
.
float
)
power_space
=
PowerSpace
(
harmonic_space
,
distribution_strategy
=
distribution_strategy
)
fft
=
FFTOperator
(
harmonic_space
,
target
=
signal_space
)
power_space
=
PowerSpace
(
harmonic_space
)
# Creating the mock data
S
=
create_power_operator
(
harmonic_space
,
power_spectrum
=
power_spectrum
,
distribution_strategy
=
distribution_strategy
)
S
=
create_power_operator
(
harmonic_space
,
power_spectrum
=
power_spectrum
)
mock_power
=
Field
(
power_space
,
val
=
power_spectrum
,
distribution_strategy
=
distribution_strategy
)
mock_power
=
Field
(
power_space
,
val
=
power_spectrum
)
np
.
random
.
seed
(
43
)
mock_harmonic
=
mock_power
.
power_synthesize
(
real_signal
=
True
)
if
nifty_configuration
[
'harmonic_rg_base'
]
==
'real'
:
mock_harmonic
=
mock_harmonic
.
real
mock_signal
=
fft
(
mock_harmonic
)
R
=
ResponseOperator
(
signal_space
,
sigma
=
(
response_sigma
,))
...
...
@@ -73,9 +72,11 @@ if __name__ == "__main__":
m_s
=
fft
(
m
)
plotter
=
plotting
.
RG2DPlotter
()
plotter
.
title
=
'mock_signal.html'
;
plotter
(
mock_signal
)
plotter
.
title
=
'data.html'
plotter
(
Field
(
signal_space
,
val
=
data
.
val
.
get_full_data
().
reshape
(
signal_space
.
shape
)))
plotter
.
title
=
'map.html'
;
plotter
(
m_s
)
\ No newline at end of file
plotter
.
path
=
'mock_signal.html'
plotter
(
mock_signal
.
real
)
plotter
.
path
=
'data.html'
plotter
(
Field
(
signal_space
,
val
=
data
.
val
.
get_full_data
().
real
.
reshape
(
signal_space
.
shape
)))
plotter
.
path
=
'map.html'
plotter
(
m_s
.
real
)
demos/wiener_filter_via_hamiltonian.py
View file @
2b4b6c07
...
...
@@ -10,6 +10,7 @@ rank = comm.rank
np
.
random
.
seed
(
42
)
class
AdjointFFTResponse
(
LinearOperator
):
def
__init__
(
self
,
FFT
,
R
,
default_spaces
=
None
):
super
(
AdjointFFTResponse
,
self
).
__init__
(
default_spaces
)
...
...
@@ -23,6 +24,7 @@ class AdjointFFTResponse(LinearOperator):
def
_adjoint_times
(
self
,
x
,
spaces
=
None
):
return
self
.
FFT
(
self
.
R
.
adjoint_times
(
x
))
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
@@ -36,13 +38,12 @@ class AdjointFFTResponse(LinearOperator):
return
False
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
# Set up position space
s_space
=
RGSpace
([
128
,
128
])
s_space
=
RGSpace
([
128
,
128
])
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
...
...
@@ -52,7 +53,8 @@ if __name__ == "__main__":
# Setting up power space
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
# Choosing the prior correlation structure and defining correlation operator
# Choosing the prior correlation structure and defining
# correlation operator
p_spec
=
(
lambda
k
:
(
42
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
p_spec
,
...
...
@@ -69,7 +71,7 @@ if __name__ == "__main__":
Instrument
=
DiagonalOperator
(
s_space
,
diagonal
=
1.
)
# Instrument._diagonal.val[200:400, 200:400] = 0
#Adding a harmonic transformation to the instrument
#
Adding a harmonic transformation to the instrument
R
=
AdjointFFTResponse
(
fft
,
Instrument
)
signal_to_noise
=
1.
N
=
DiagonalOperator
(
s_space
,
diagonal
=
ss
.
var
()
/
signal_to_noise
,
bare
=
True
)
...
...
@@ -84,9 +86,9 @@ if __name__ == "__main__":
# Choosing the minimization strategy
def
convergence_measure
(
energy
,
iteration
):
# returns current energy
def
convergence_measure
(
energy
,
iteration
):
# returns current energy
x
=
energy
.
value
print
(
x
,
iteration
)
print
(
x
,
iteration
)
# minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50,
...
...
@@ -109,20 +111,19 @@ if __name__ == "__main__":
m0
=
Field
(
h_space
,
val
=
.
0
)
# Initializing the Wiener Filter energy
energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
inverter
=
inverter
)
energy
=
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
)
D0
=
energy
.
curvature
# Solving the problem analytically
m0
=
D0
.
inverse_times
(
j
)
sample_variance
=
Field
(
sh
.
domain
,
val
=
0.
+
0j
)
sample_mean
=
Field
(
sh
.
domain
,
val
=
0.
+
0j
)
sample_variance
=
Field
(
sh
.
domain
,
val
=
0.
+
0j
)
sample_mean
=
Field
(
sh
.
domain
,
val
=
0.
+
0j
)
# sampling the uncertainty map
n_samples
=
1
for
i
in
range
(
n_samples
):
sample
=
sugar
.
generate_posterior_sample
(
m0
,
D0
)
sample
=
sugar
.
generate_posterior_sample
(
m0
,
D0
)
sample_variance
+=
sample
**
2
sample_mean
+=
sample
variance
=
sample_variance
/
n_samples
-
(
sample_mean
/
n_samples
)
nifty/basic_arithmetics.py
View file @
2b4b6c07
...
...
@@ -24,7 +24,7 @@ from .field import Field
__all__
=
[
'cos'
,
'sin'
,
'cosh'
,
'sinh'
,
'tan'
,
'tanh'
,
'arccos'
,
'arcsin'
,
'arccosh'
,
'arcsinh'
,
'arctan'
,
'arctanh'
,
'sqrt'
,
'exp'
,
'log'
,
'conjugate'
,
'clipped_exp'
]
'conjugate'
,
'clipped_exp'
,
'limitted_exp'
]
def
_math_helper
(
x
,
function
):
...
...
@@ -100,6 +100,19 @@ def clipped_exp(x):
return
_math_helper
(
x
,
lambda
z
:
np
.
exp
(
np
.
minimum
(
200
,
z
)))
def
limitted_exp
(
x
):
thr
=
200
expthr
=
np
.
exp
(
thr
)
return
_math_helper
(
x
,
lambda
z
:
_limitted_exp_helper
(
z
,
thr
,
expthr
))
def
_limitted_exp_helper
(
x
,
thr
,
expthr
):
mask
=
(
x
>
thr
)
result
=
np
.
exp
(
x
)
result
[
mask
]
=
((
1
-
thr
)
+
x
[
mask
])
*
expthr
return
result
def
log
(
x
,
base
=
None
):
result
=
_math_helper
(
x
,
np
.
log
)
if
base
is
not
None
:
...
...
nifty/config/nifty_config.py
View file @
2b4b6c07
...
...
@@ -70,11 +70,18 @@ variable_default_distribution_strategy = keepers.Variable(
if
z
==
'fftw'
else
True
),
genus
=
'str'
)
variable_harmonic_rg_base
=
keepers
.
Variable
(
'harmonic_rg_base'
,
[
'real'
,
'complex'
],
lambda
z
:
z
in
[
'real'
,
'complex'
],
genus
=
'str'
)
nifty_configuration
=
keepers
.
get_Configuration
(
name
=
'NIFTy'
,
variables
=
[
variable_fft_module
,
variable_default_field_dtype
,
variable_default_distribution_strategy
],
variable_default_distribution_strategy
,
variable_harmonic_rg_base
],
file_name
=
'NIFTy.conf'
,
search_paths
=
[
os
.
path
.
expanduser
(
'~'
)
+
"/.config/nifty/"
,
os
.
path
.
expanduser
(
'~'
)
+
"/.config/"
,
...
...
nifty/energies/energy.py
View file @
2b4b6c07
...
...
@@ -65,12 +65,9 @@ class Energy(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {})))
"""
def
__init__
(
self
,
position
):
super
(
Energy
,
self
).
__init__
()
self
.
_cache
=
{}
try
:
position
=
position
.
copy
()
except
AttributeError
:
pass
self
.
_position
=
position
self
.
_position
=
position
.
copy
()
def
at
(
self
,
position
):
""" Initializes and returns a new Energy object at the new position.
...
...
nifty/energies/line_energy.py
View file @
2b4b6c07
...
...
@@ -70,12 +70,16 @@ class LineEnergy(object):
"""
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
super
(
LineEnergy
,
self
).
__init__
()
self
.
_line_position
=
float
(
line_position
)
self
.
_line_direction
=
line_direction
pos
=
energy
.
position
\
+
(
self
.
_line_position
-
float
(
offset
))
*
self
.
_line_direction
self
.
energy
=
energy
.
at
(
position
=
pos
)
if
self
.
_line_position
==
float
(
offset
):
self
.
energy
=
energy
else
:
pos
=
energy
.
position
\
+
(
self
.
_line_position
-
float
(
offset
))
*
self
.
_line_direction
self
.
energy
=
energy
.
at
(
position
=
pos
)
def
at
(
self
,
line_position
):
""" Returns LineEnergy at new position, memorizing the zero point.
...
...
nifty/field.py
View file @
2b4b6c07
...
...
@@ -22,7 +22,6 @@ from builtins import zip
from
builtins
import
range
import
ast
import
itertools
import
numpy
as
np
from
keepers
import
Versionable
,
\
...
...
@@ -410,7 +409,6 @@ class Field(Loggable, Versionable, object):
distribution_strategy
=
distribution_strategy
,
logarithmic
=
logarithmic
,
nbin
=
nbin
,
binbounds
=
binbounds
)
power_spectrum
=
cls
.
_calculate_power_spectrum
(
field_val
=
work_field
.
val
,
pdomain
=
power_domain
,
...
...
@@ -441,6 +439,7 @@ class Field(Loggable, Versionable, object):
target_shape
=
field_val
.
shape
,
target_strategy
=
field_val
.
distribution_strategy
,
axes
=
axes
)
power_spectrum
=
pindex
.
bincount
(
weights
=
field_val
,
axis
=
axes
)
rho
=
pdomain
.
rho
...
...
@@ -466,14 +465,14 @@ class Field(Loggable, Versionable, object):
"A slicing distributor shall not be reshaped to "
"something non-sliced."
)
semiscaled_shape
=
[
1
,
]
*
len
(
target_shape
)
for
i
in
axes
:
semiscaled_shape
[
i
]
=
target
_shape
[
i
]
semiscaled_
local_
shape
=
[
1
,
]
*
len
(
target_shape
)
for
i
in
range
(
len
(
axes
))
:
semiscaled_
local_
shape
[
axes
[
i
]
]
=
pindex
.
local
_shape
[
i
]
local_data
=
pindex
.
get_local_data
(
copy
=
False
)
semiscaled_local_data
=
local_data
.
reshape
(
semiscaled_shape
)
semiscaled_local_data
=
local_data
.
reshape
(
semiscaled_
local_
shape
)
result_obj
=
pindex
.
copy_empty
(
global_shape
=
target_shape
,
distribution_strategy
=
target_strategy
)
result_obj
.
set_full_data
(
semiscaled_local_data
,
copy
=
False
)
result_obj
.
data
[:]
=
semiscaled_local_data
return
result_obj
...
...
nifty/library/wiener_filter/wiener_filter_energy.py
View file @
2b4b6c07
...
...
@@ -7,7 +7,7 @@ class WienerFilterEnergy(Energy):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
Gaussian noise and Gauss
a
in signal prior with known covariance.
Gaussian noise and Gaussi
a
n signal prior with known covariance.
Parameters
----------
...
...
nifty/minimization/descent_minimizer.py
View file @
2b4b6c07
...
...
@@ -123,7 +123,6 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
convergence
=
0
f_k_minus_1
=
None
step_length
=
0
iteration_number
=
1
while
True
:
...
...
@@ -150,7 +149,7 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
# compute the step length, which minimizes energy.value along the
# search direction
try
:
step_length
,
f_k
,
new_energy
=
\
new_energy
=
\
self
.
line_searcher
.
perform_line_search
(
energy
=
energy
,
pk
=
descent_direction
,
...
...
@@ -160,12 +159,10 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
"Stopping because of RuntimeError in line-search"
)
break
if
f_k_minus_1
is
None
:
delta
=
1e30
else
:
delta
=
(
abs
(
f_k
-
f_k_minus_1
)
/
max
(
abs
(
f_k
),
abs
(
f_k_minus_1
),
1.
))
f_k_minus_1
=
energy
.
value
f_k
=
new_energy
.
value
delta
=
(
abs
(
f_k
-
f_k_minus_1
)
/
max
(
abs
(
f_k
),
abs
(
f_k_minus_1
),
1.
))
# check if new energy value is bigger than old energy value
if
(
new_energy
.
value
-
energy
.
value
)
>
0
:
self
.
logger
.
info
(
"Line search algorithm returned a new energy "
...
...
@@ -174,9 +171,9 @@ class DescentMinimizer(with_metaclass(NiftyMeta, type('NewBase', (Loggable, obje
energy
=
new_energy
# check convergence
self
.
logger
.
debug
(
"Iteration:%08u
step_length=%3.1E
"
self
.
logger
.
debug
(
"Iteration:%08u "
"delta=%3.1E energy=%3.1E"
%
(
iteration_number
,
step_length
,
delta
,
(
iteration_number
,
delta
,
energy
.
value
))
if
delta
==
0
:
convergence
=
self
.
convergence_level
+
2
...
...
nifty/minimization/line_searching/line_search.py
View file @
2b4b6c07
...
...
@@ -35,8 +35,6 @@ class LineSearch(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object),
----------
line_energy : LineEnergy Object
LineEnergy object from which we can extract energy at a specific point.
f_k_minus_1 : Field
Value of the field at the k-1 iteration of the line search procedure.
preferred_initial_step_size : float
Initial guess for the step length.
...
...
@@ -45,32 +43,8 @@ class LineSearch(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object),
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
):
self
.
line_energy
=
None
self
.
f_k_minus_1
=
None
self
.
preferred_initial_step_size
=
None
def
_set_line_energy
(
self
,
energy
,
pk
,
f_k_minus_1
=
None
):
"""Set the coordinates for a new line search.
Parameters
----------
energy : Energy object
Energy object from which we can calculate the energy, gradient and
curvature at a specific point.
pk : Field
Unit vector pointing into the search direction.
f_k_minus_1 : float
Value of the fuction (energy) which will be minimized at the k-1
iteration of the line search procedure. (Default: None)
"""
self
.
line_energy
=
LineEnergy
(
line_position
=
0.
,
energy
=
energy
,
line_direction
=
pk
)
if
f_k_minus_1
is
not
None
:
f_k_minus_1
=
f_k_minus_1
.
copy
()
self
.
f_k_minus_1
=
f_k_minus_1
@
abc
.
abstractmethod
def
perform_line_search
(
self
,
energy
,
pk
,
f_k_minus_1
=
None
):
raise
NotImplementedError
nifty/minimization/line_searching/line_search_strong_wolfe.py
View file @
2b4b6c07
...
...
@@ -22,12 +22,13 @@ from builtins import range
import
numpy
as
np
from
.line_search
import
LineSearch
from
...energies
import
LineEnergy
class
LineSearchStrongWolfe
(
LineSearch
):
"""Class for finding a step size that satisfies the strong Wolfe conditions.
Algorithm contains two stages. It begins w
h
it a trial step length and
Algorithm contains two stages. It begins wit
h
a trial step length and
keeps increasing it until it finds an acceptable step length or an
interval. If it does not satisfy the Wolfe conditions, it performs the Zoom
algorithm (second stage). By interpolating it decreases the size of the
...
...
@@ -80,8 +81,8 @@ class LineSearchStrongWolfe(LineSearch):
"""Performs the first stage of the algorithm.
It starts with a trial step size and it keeps increasing it until it
satisf
y
the strong Wolf conditions. It also performs the descent and
returns the optimal step length and the new enrgy.
satisf
ies
the strong Wolf conditions. It also performs the descent and
returns the optimal step length and the new en
e
rgy.
Parameters
----------
...
...
@@ -89,29 +90,22 @@ class LineSearchStrongWolfe(LineSearch):
Energy object from which we will calculate the energy and the
gradient at a specific point.
pk : Field
Unit v
ector pointing into the search direction.
V
ector pointing into the search direction.
f_k_minus_1 : float
Value of the fuction (which is being minimized) at the k-1
iteration of the line search procedure. (Default: None)
Returns
-------
alpha_star : float
The optimal step length in the descent direction.
phi_star : float
Value of the energy after the performed descent.
energy_star : Energy object
The new Energy object on the new position.
"""
self
.
_set_line_energy
(
energy
,
pk
,
f_k_minus_1
=
f_k_minus_1
)
max_step_size
=
self
.
max_step_size
max_iterations
=
self
.
max_iterations
le_0
=
LineEnergy
(
0.
,
energy
,
pk
,
0.
)
# initialize the zero phis
old_phi_0
=
self
.
f_k_minus_1
le_0
=
self
.
line_energy
.
at
(
0
)
old_phi_0
=
f_k_minus_1
phi_0
=
le_0
.
value
phiprime_0
=
le_0
.
directional_derivative
if
phiprime_0
>=
0
:
...
...
@@ -120,6 +114,9 @@ class LineSearchStrongWolfe(LineSearch):
# set alphas
alpha0
=
0.
phi_alpha0
=
phi_0
phiprime_alpha0
=
phiprime_0
if
self
.
preferred_initial_step_size
is
not
None
:
alpha1
=
self
.
preferred_initial_step_size
elif
old_phi_0
is
not
None
and
phiprime_0
!=
0
:
...
...
@@ -129,73 +126,48 @@ class LineSearchStrongWolfe(LineSearch):
else
:
alpha1
=
1.0
# give the alpha0 phis the right init value
phi_alpha0
=
phi_0
phiprime_alpha0
=
phiprime_0
# start the minimization loop
for
i
in
range
(
max_iterations
):
le_alpha1
=
self
.
line_energy
.
at
(
alpha1
)
phi_alpha1
=
le_alpha1
.
value
for
i
in
range
(
self
.
max_iterations
):
if
alpha1
==
0
:
self
.
logger
.
warn
(
"Increment size became 0."
)
alpha_star
=
0.
phi_star
=
phi_0
le_star
=
le_0
break
return
le_0
.
energy
le_alpha1
=
le_0
.
at
(
alpha1
)
phi_alpha1
=
le_alpha1
.
value
if
(
phi_alpha1
>
phi_0
+
self
.
c1
*
alpha1
*
phiprime_0
)
or
\
((
phi_alpha1
>=
phi_alpha0
)
and
(
i
>
0
)):
(
alpha_star
,
phi_star
,
le_star
)
=
self
.
_zoom
(
alpha0
,
alpha1
,
phi_0
,
phiprime_0
,
phi_alpha0
,
phiprime_alpha0
,
phi_alpha1
)
break
le_star
=
self
.
_zoom
(
alpha0
,
alpha1
,
phi_0
,
phiprime_0
,
phi_alpha0
,
phiprime_alpha0
,
phi_alpha1
,
le_0
)
return
le_star
.
energy
phiprime_alpha1
=
le_alpha1
.
directional_derivative
if
abs
(
phiprime_alpha1
)
<=
-
self
.
c2
*
phiprime_0
:
alpha_star
=
alpha1
phi_star
=
phi_alpha1
le_star
=
le_alpha1
break
return
le_alpha1
.
energy
if
phiprime_alpha1
>=
0
:
(
alpha_star
,
phi_star
,
le_star
)
=
self
.
_zoom
(
alpha1
,
alpha0
,
phi_0
,
phiprime_0
,
phi_alpha1
,
phiprime_alpha1
,
phi_alpha0
)
break
le_star
=
self
.
_zoom
(
alpha1
,
alpha0
,
phi_0
,
phiprime_0
,
phi_alpha1
,
phiprime_alpha1
,
phi_alpha0
,
le_0
)
return
le_star
.
energy
# update alphas
alpha0
,
alpha1
=
alpha1
,
min
(
2
*
alpha1
,
max_step_size
)
if
alpha1
==
max_step_size
:
alpha0
,
alpha1
=
alpha1
,
min
(
2
*
alpha1
,
self
.
max_step_size
)
if
alpha1
==
self
.
max_step_size
:
print
(
"reached max step size, bailing out"
)
alpha_star
=
alpha1
phi_star
=
phi_alpha1
le_star
=
le_alpha1
break
return
le_alpha1
.
energy
phi_alpha0
=
phi_alpha1
phiprime_alpha0
=
phiprime_alpha1
else
:
# max_iterations was reached
alpha_star
=
alpha1
phi_star
=
phi_alpha1
le_star
=
le_alpha1
self
.
logger
.
error
(
"The line search algorithm did not converge."
)
# extract the full energy from the line_energy
energy_star
=
le_star
.
energy
direction_length
=
pk
.
norm
()
step_length
=
alpha_star
*
direction_length
return
step_length
,
phi_star
,
energy_star
return
le_alpha1
.
energy
def
_zoom
(
self
,
alpha_lo
,
alpha_hi
,
phi_0
,
phiprime_0
,
phi_lo
,
phiprime_lo
,
phi_hi
):
phi_lo
,
phiprime_lo
,
phi_hi
,
le_0
):
"""Performs the second stage of the line search algorithm.
If the first stage was not successful then the Zoom algorithm tries to
...
...
@@ -226,32 +198,23 @@ class LineSearchStrongWolfe(LineSearch):
Returns
-------
alpha_star : float
The optimal step length in the descent direction.
phi_star : float
Value of the energy after the performed descent.
energy_star : Energy object
The new Energy object on the new position.
"""
max_iterations
=
self
.
max_zoom_iterations
# define the cubic and quadratic interpolant checks
cubic_delta
=
0.2
# cubic
quad_delta
=
0.1
# quadratic
phiprime_alphaj
=
0.
alpha_recent
=
None
phi_recent
=
None
assert
phi_lo
<=
phi_0
+
self
.
c1
*
alpha_lo
*
phiprime_0
assert
phiprime_lo
*
(
alpha_hi
-
alpha_lo
)
<
0.
for
i
in
range
(
max
_iterations
):
for
i
in
range
(
self
.
max_zoom
_iterations
):
# assert phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0
# assert phiprime_lo*(alpha_hi-alpha_lo)<0.
delta_alpha
=
alpha_hi
-
alpha_lo
if
delta_alpha
<
0
:
a
,
b
=
alpha_hi
,
alpha_lo
else
:
a
,
b
=
alpha_lo
,
alpha_hi
a
,
b
=
min
(
alpha_lo
,
alpha_hi
),
max
(
alpha_lo
,
alpha_hi
)
# Try cubic interpolation
if
i
>
0
:
...
...
@@ -271,12 +234,12 @@ class LineSearchStrongWolfe(LineSearch):
alpha_j
=
alpha_lo
+
0.5
*
delta_alpha