NIFTy commit 3c5a2a01, authored Jul 24, 2019 by Philipp Arras

Rewrite NewtonCG

parent abe02b37
Pipeline #52374 passed with stages in 8 minutes and 38 seconds
Changes: 1
nifty5/minimization/descent_minimizers.py
@@ -18,8 +18,11 @@
 import numpy as np
 
 from ..logger import logger
+from .conjugate_gradient import ConjugateGradient
+from .iteration_controllers import GradientNormController
 from .line_search import LineSearch
 from .minimizer import Minimizer
+from .quadratic_energy import QuadraticEnergy
 
 
 class DescentMinimizer(Minimizer):
@@ -154,49 +157,22 @@ class NewtonCG(DescentMinimizer):
     Algorithm derived from SciPy sources.
     """
 
-    def __init__(self, controller, napprox=0, line_searcher=None):
+    def __init__(self, controller, line_searcher=None):
         if line_searcher is None:
             line_searcher = LineSearch(preferred_initial_step_size=1.)
         super(NewtonCG, self).__init__(controller=controller,
                                        line_searcher=line_searcher)
-        self._napprox = int(napprox)
 
     def get_descent_direction(self, energy):
-        # if self._napprox > 1:
-        #     from ..probing import approximation2endo
-        #     sqdiag = approximation2endo(energy.metric, self._napprox).sqrt()
-        float64eps = np.finfo(np.float64).eps
-        r = energy.gradient
-        maggrad = abs(r).sum()
+        g = energy.gradient
+        maggrad = abs(g).sum()
         termcond = np.min([0.5, np.sqrt(maggrad)])*maggrad
-        pos = energy.position*0
-        d = r
-        previous_gamma = r.vdot(d)
-        ii = 0
-        while True:
-            if not ii % 10 and ii > 0:
-                print(ii)
-            if abs(r).sum() <= termcond:
-                return pos
-            q = energy.apply_metric(d)
-            curv = d.vdot(q)
-            if 0 <= curv <= 3*float64eps:
-                return pos
-            if curv < 0:
-                return pos if ii > 0 else previous_gamma/curv*r
-            ii += 1
-            alpha = previous_gamma/curv
-            pos = pos - alpha*d
-            r = r - alpha*q
-            s = r
-            gamma = r.vdot(s)
-            d = d*(gamma/previous_gamma) + r
-            previous_gamma = gamma
-        # curvature keeps increasing, bail out
-        raise ValueError("Warning: CG iterations didn't converge. "
-                         "The Hessian is not positive definite.")
+        ic = GradientNormController(tol_abs_gradnorm=termcond, p=1)
+        e = QuadraticEnergy(0*energy.position, energy.metric, g)
+        e, conv = ConjugateGradient(ic, nreset=np.inf)(e)
+        if conv == ic.ERROR:
+            raise RuntimeError
+        return -e.position
 
 
 class L_BFGS(DescentMinimizer):
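For context (not part of the commit): the rewritten get_descent_direction no longer runs a hand-rolled CG loop. It wraps the current gradient g and the metric M into a QuadraticEnergy 0.5*x^T M x - g^T x, minimizes it from zero with the ConjugateGradient minimizer, and returns the negated result, i.e. a truncated Newton step approximating -M^{-1} g, with the same forcing-term tolerance min(0.5, sqrt(|g|_1))*|g|_1 as before. A minimal NumPy sketch of that equivalence (newton_cg_direction is an illustrative helper, not NIFTy code; a plain SPD matrix stands in for energy.metric):

import numpy as np

def newton_cg_direction(metric, grad, termcond, maxiter=100):
    # CG on the quadratic energy 0.5*x^T M x - g^T x, started at zero,
    # converges to x = M^{-1} g; the descent direction is the negated
    # iterate, i.e. the (truncated) Newton step -M^{-1} g.
    x = np.zeros_like(grad)
    r = grad.copy()                      # residual of M x = g at x = 0
    d = r.copy()
    gamma = r @ r
    for _ in range(maxiter):
        if np.abs(r).sum() <= termcond:  # 1-norm stopping rule (p=1 above)
            break
        q = metric @ d
        alpha = gamma / (d @ q)
        x += alpha * d
        r -= alpha * q
        gamma_new = r @ r
        d = r + (gamma_new / gamma) * d
        gamma = gamma_new
    return -x

rng = np.random.default_rng(0)
B = rng.normal(size=(5, 5))
M = B @ B.T + 5*np.eye(5)                # toy SPD matrix standing in for energy.metric
g = rng.normal(size=5)                   # toy gradient

# With the loose forcing term from the diff the step is only approximate ...
maggrad = np.abs(g).sum()
approx = newton_cg_direction(M, g, termcond=min(0.5, np.sqrt(maggrad))*maggrad)
print(g @ approx < 0)                    # ... but still a descent direction: True

# With a tight tolerance CG recovers the exact Newton step -M^{-1} g.
exact = newton_cg_direction(M, g, termcond=1e-12)
print(np.allclose(exact, -np.linalg.solve(M, g)))   # True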
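A hedged usage sketch of the rewritten class, assuming the nifty5 API around the time of this commit; the toy space, operator, and controller settings below are illustrative choices, not taken from the commit:

import nifty5 as ift

# Toy convex problem: E(x) = 0.5 x^T A x - b^T x on a 64-pixel grid,
# whose exact minimum is A^{-1} b = 0.5 everywhere.
space = ift.RGSpace(64)
A = ift.DiagonalOperator(ift.full(space, 2.))       # positive definite metric
b = ift.full(space, 1.)
energy = ift.QuadraticEnergy(position=ift.full(space, 0.), A=A, b=b)

# Note: after this commit NewtonCG no longer accepts an `napprox` argument.
ic = ift.GradientNormController(name='Newton', tol_abs_gradnorm=1e-8,
                                iteration_limit=20)
minimizer = ift.NewtonCG(ic)
energy, convergence = minimizer(energy)             # standard Minimizer interface
print(energy.position.to_global_data())             # close to 0.5 everywhere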