Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
648587ef
Commit
648587ef
authored
Aug 19, 2017
by
Martin Reinecke
Browse files
first try
parent
e1843d4d
Pipeline
#16876
canceled with stage
in 9 minutes and 56 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/energies/__init__.py
View file @
648587ef
...
...
@@ -17,5 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from
energy
import
Energy
from
quadratic_energy
import
QuadraticEnergy
from
line_energy
import
LineEnergy
from
memoization
import
memo
nifty/energies/energy.py
View file @
648587ef
...
...
@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from
nifty.nifty_meta
import
NiftyMeta
from
nifty.energies.memoization
import
memo
from
keepers
import
Loggable
...
...
@@ -40,7 +41,7 @@ class Energy(Loggable, object):
value : np.float
The value of the energy functional at given `position`.
gradient : Field
The gradient at given `position`
in parameter direction
.
The gradient at given `position`.
curvature : LinearOperator, callable
A positive semi-definite operator or function describing the curvature
of the potential at the given `position`.
...
...
@@ -109,12 +110,32 @@ class Energy(Loggable, object):
@
property
def
gradient
(
self
):
"""
The gradient at given `position`
in parameter direction
.
The gradient at given `position`.
"""
raise
NotImplementedError
@
property
@
memo
def
gradient_norm
(
self
):
"""
The length of the gradient at given `position`.
"""
return
self
.
gradient
.
norm
()
@
property
@
memo
def
gradient_infnorm
(
self
):
"""
The infinity norm of the gradient at given `position`.
"""
return
abs
(
self
.
gradient
).
max
()
@
property
def
curvature
(
self
):
"""
...
...
nifty/energies/quadratic_energy.py
0 → 100644
View file @
648587ef
from
nifty.energies.energy
import
Energy
from
nifty.energies.memoization
import
memo
class
QuadraticEnergy
(
Energy
):
"""The Energy for a quadratic form.
"""
def
__init__
(
self
,
position
,
A
,
b
):
super
(
QuadraticEnergy
,
self
).
__init__
(
position
=
position
)
self
.
_A
=
A
self
.
_b
=
b
def
at
(
self
,
position
):
return
self
.
__class__
(
position
=
position
,
A
=
self
.
_A
,
b
=
self
.
_b
)
@
property
@
memo
def
value
(
self
):
return
0.5
*
self
.
position
.
vdot
(
self
.
_Ax
)
-
self
.
_b
.
vdot
(
self
.
position
)
@
property
@
memo
def
gradient
(
self
):
return
self
.
_Ax
-
self
.
_b
@
property
@
memo
def
curvature
(
self
):
return
self
.
_A
@
property
@
memo
def
_Ax
(
self
):
return
self
.
curvature
(
self
.
position
)
nifty/minimization/conjugate_gradient.py
View file @
648587ef
...
...
@@ -90,60 +90,54 @@ class ConjugateGradient(Loggable, object):
reset_count
=
int
(
reset_count
)
self
.
reset_count
=
reset_count
if
preconditioner
is
None
:
preconditioner
=
lambda
z
:
z
self
.
preconditioner
=
preconditioner
self
.
callback
=
callback
def
__call__
(
self
,
A
,
b
,
x0
):
def
__call__
(
self
,
E
):
""" Runs the conjugate gradient minimization.
For `Ax = b` the variable `x` is infered.
Parameters
----------
A : Operator
Operator `A` applicable to a Field.
b : Field
Result of the operation `A(x)`.
x0 : Field
Starting guess for the minimization.
E : Energy object at the starting point of the iteration.
E's curvature operator must be independent of position, otherwise
linear conjugate gradient minimization will fail.
Returns
-------
x : Field
Latest `x` of the minimization.
E : QuadraticEnergy at last point of the iteration
convergence : integer
Latest convergence level indicating whether the minimization
has converged or not.
"""
r
=
b
-
A
(
x0
)
d
=
self
.
preconditioner
(
r
)
r
=
-
E
.
gradient
if
self
.
preconditioner
is
not
None
:
d
=
self
.
preconditioner
(
r
)
else
:
d
=
r
.
copy
()
previous_gamma
=
(
r
.
vdot
(
d
)).
real
if
previous_gamma
==
0
:
self
.
logger
.
info
(
"The starting guess is already perfect solution "
"for the inverse problem."
)
return
x0
,
self
.
convergence_level
+
1
norm_b
=
np
.
sqrt
((
b
.
vdot
(
b
)).
real
)
x
=
x0
.
copy
()
return
E
,
self
.
convergence_level
+
1
convergence
=
0
iteration_number
=
1
self
.
logger
.
info
(
"Starting conjugate gradient."
)
while
True
:
if
self
.
callback
is
not
None
:
self
.
callback
(
x
,
iteration_number
)
self
.
callback
(
E
,
iteration_number
)
q
=
A
(
d
)
alpha
=
previous_gamma
/
d
.
vdot
(
q
).
real
q
=
E
.
curvature
(
d
)
alpha
=
previous_gamma
/
(
d
.
vdot
(
q
).
real
)
if
not
np
.
isfinite
(
alpha
):
self
.
logger
.
error
(
"Alpha became infinite! Stopping."
)
return
x0
,
0
return
E
,
0
x
+=
d
*
alpha
E
=
E
.
at
(
E
.
position
+
d
*
alpha
)
reset
=
False
if
alpha
<
0
:
...
...
@@ -153,20 +147,23 @@ class ConjugateGradient(Loggable, object):
reset
+=
(
iteration_number
%
self
.
reset_count
==
0
)
if
reset
:
self
.
logger
.
info
(
"Resetting conjugate directions."
)
r
=
b
-
A
(
x
)
r
=
-
E
.
gradient
else
:
r
-=
q
*
alpha
s
=
self
.
preconditioner
(
r
)
if
self
.
preconditioner
is
not
None
:
s
=
self
.
preconditioner
(
r
)
else
:
s
=
r
.
copy
()
gamma
=
r
.
vdot
(
s
).
real
if
gamma
<
0
:
self
.
logger
.
warn
(
"Positive definitness of preconditioner "
self
.
logger
.
warn
(
"Positive definit
e
ness of preconditioner "
"violated!"
)
beta
=
max
(
0
,
gamma
/
previous_gamma
)
delta
=
np
.
sqrt
(
gamma
)
/
norm
_b
delta
=
r
.
norm
()
self
.
logger
.
debug
(
"Iteration : %08u alpha = %3.1E "
"beta = %3.1E delta = %3.1E"
%
...
...
@@ -196,4 +193,4 @@ class ConjugateGradient(Loggable, object):
iteration_number
+=
1
previous_gamma
=
gamma
return
x
,
convergence
return
E
,
convergence
test/test_minimization/test_conjugate_gradient.py
View file @
648587ef
import
unittest
import
numpy
as
np
from
numpy.testing
import
assert_equal
,
assert_al
most_equal
from
numpy.testing
import
assert_equal
,
assert_al
lclose
from
nifty
import
Field
,
DiagonalOperator
,
RGSpace
,
HPSpace
from
nifty
import
ConjugateGradient
from
nifty
import
ConjugateGradient
,
QuadraticEnergy
from
test.common
import
expand
...
...
@@ -38,10 +38,11 @@ class Test_ConjugateGradient(unittest.TestCase):
required_result
=
Field
(
space
,
val
=
1.
)
minimizer
=
ConjugateGradient
()
energy
=
QuadraticEnergy
(
A
=
covariance
,
b
=
required_result
,
position
=
starting_point
)
(
position
,
convergence
)
=
minimizer
(
A
=
covariance
,
x0
=
starting_point
,
b
=
required_result
)
(
energy
,
convergence
)
=
minimizer
(
energy
)
assert_al
most_equal
(
position
.
val
.
get_full_data
(),
assert_al
lclose
(
energy
.
position
.
val
.
get_full_data
(),
1.
/
covariance_diagonal
.
val
.
get_full_data
(),
decimal
=
3
)
rtol
=
1e-
3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment