Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
17b1c6e9
Commit
17b1c6e9
authored
Oct 25, 2016
by
theos
Browse files
Implemented Energy object. Adapted minimizers to it.
parent
41b9419d
Changes
8
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter_hamiltonian.py
View file @
17b1c6e9
from
nifty
import
*
import
plotly.offline
as
pl
import
plotly.graph_objs
as
go
...
...
@@ -8,17 +9,47 @@ comm = MPI.COMM_WORLD
rank
=
comm
.
rank
class
WienerFilterEnergy
(
Energy
):
def
__init__
(
self
,
position
,
D
,
j
):
# in principle not necessary, but useful in order to make the signature
# explicit
super
(
WienerFilterEnergy
,
self
).
__init__
(
position
)
self
.
D
=
D
self
.
j
=
j
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
D
=
self
.
D
,
j
=
self
.
j
)
@
property
def
value
(
self
):
D_inv_x
=
self
.
D_inverse_x
()
H
=
0.5
*
D_inv_x
.
dot
(
self
.
position
)
-
self
.
j
.
dot
(
self
.
position
)
return
H
.
real
@
property
def
gradient
(
self
):
D_inv_x
=
self
.
D_inverse_x
()
g
=
D_inv_x
-
self
.
j
return_g
=
g
.
copy_empty
(
dtype
=
np
.
float
)
return_g
.
val
=
g
.
val
.
real
return
return_g
def
D_inverse_x
(
self
):
return
D
.
inverse_times
(
self
.
position
)
if
__name__
==
"__main__"
:
distribution_strategy
=
'fftw'
# Set up spaces and fft transformation
s_space
=
RGSpace
([
512
,
512
],
dtype
=
np
.
float
)
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
# create the field instances and power operator
pow_spec
=
(
lambda
k
:
(
42
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
,
distribution_strategy
=
distribution_strategy
)
...
...
@@ -27,8 +58,8 @@ if __name__ == "__main__":
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
ss
=
fft
.
inverse_times
(
sh
)
# model the measurement process
R
=
SmoothingOperator
(
s_space
,
sigma
=
0.01
)
# R = DiagonalOperator(s_space, diagonal=1.)
# R._diagonal.val[200:400, 200:400] = 0
...
...
@@ -38,70 +69,67 @@ if __name__ == "__main__":
random_type
=
'normal'
,
std
=
ss
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0
)
#n.val.data.imag[:] = 0
# create mock data
d
=
R
(
ss
)
+
n
# set up reconstruction objects
j
=
R
.
adjoint_times
(
N
.
inverse_times
(
d
))
D
=
PropagatorOperator
(
S
=
S
,
N
=
N
,
R
=
R
)
def
energy
(
x
):
DIx
=
D
.
inverse_times
(
x
)
H
=
0.5
*
DIx
.
dot
(
x
)
-
j
.
dot
(
x
)
return
H
.
real
def
gradient
(
x
):
DIx
=
D
.
inverse_times
(
x
)
g
=
DIx
-
j
return_g
=
g
.
copy_empty
(
dtype
=
np
.
float
)
return_g
.
val
=
g
.
val
.
real
return
return_g
def
distance_measure
(
x
,
fgrad
,
iteration
):
print
(
iteration
,
((
x
-
ss
).
norm
()
/
ss
.
norm
()).
real
)
def
distance_measure
(
energy
,
iteration
):
pass
#print (iteration, ((x-ss).norm()/ss.norm()).real)
minimizer
=
SteepestDescent
(
convergence_tolerance
=
0
,
iteration_limit
=
50
,
callback
=
distance_measure
)
minimizer
=
VL_BFGS
(
convergence_tolerance
=
0
,
iteration_limit
=
50
,
callback
=
distance_measure
,
max_history_length
=
5
)
# minimizer = VL_BFGS(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure,
# max_history_length=5)
m0
=
Field
(
s_space
,
val
=
1
)
(
m
,
convergence
)
=
minimizer
(
m0
,
energy
,
gradient
)
grad
=
gradient
(
m
)
d_data
=
d
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
ss_data
=
ss
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
ss_data
)],
filename
=
'ss.html'
)
sh_data
=
sh
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
sh_data
)],
filename
=
'sh.html'
)
j_data
=
j
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
j_data
)],
filename
=
'j.html'
)
jabs_data
=
np
.
abs
(
j
.
val
.
get_full_data
())
jphase_data
=
np
.
angle
(
j
.
val
.
get_full_data
())
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
jabs_data
)],
filename
=
'j_abs.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
jphase_data
)],
filename
=
'j_phase.html'
)
m_data
=
m
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
grad_data
=
grad
.
val
.
get_full_data
().
real
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
grad_data
)],
filename
=
'grad.html'
)
energy
=
WienerFilterEnergy
(
position
=
m0
,
D
=
D
,
j
=
j
)
(
energy
,
convergence
)
=
minimizer
(
energy
)
#
#
#
# grad = gradient(m)
#
# d_data = d.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
#
# ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
#
# sh_data = sh.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
#
# j_data = j.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
# jabs_data = np.abs(j.val.get_full_data())
# jphase_data = np.angle(j.val.get_full_data())
# if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
#
# m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
nifty/__init__.py
View file @
17b1c6e9
...
...
@@ -40,6 +40,8 @@ from config import dependency_injector,\
from
d2o
import
distributed_data_object
,
d2o_librarian
from
energies
import
*
from
field
import
Field
from
random
import
Random
...
...
nifty/energies/__init__.py
0 → 100644
View file @
17b1c6e9
# -*- coding: utf-8 -*-
from
energy
import
Energy
from
line_energy
import
LineEnergy
nifty/energies/energy.py
0 → 100644
View file @
17b1c6e9
# -*- coding: utf-8 -*-
class
Energy
(
object
):
def
__init__
(
self
,
position
):
self
.
_cache
=
{}
try
:
position
=
position
.
copy
()
except
AttributeError
:
pass
self
.
position
=
position
def
at
(
self
,
position
):
return
self
.
__class__
(
position
)
@
property
def
value
(
self
):
raise
NotImplementedError
@
property
def
gradient
(
self
):
raise
NotImplementedError
@
property
def
curvature
(
self
):
raise
NotImplementedError
def
memo
(
f
):
name
=
id
(
f
)
def
wrapped_f
(
self
):
try
:
return
self
.
_cache
[
name
]
except
KeyError
:
self
.
_cache
[
name
]
=
f
(
self
)
return
self
.
_cache
[
name
]
return
wrapped_f
nifty/energies/line_energy.py
0 → 100644
View file @
17b1c6e9
# -*- coding: utf-8 -*-
from
.energy
import
Energy
class
LineEnergy
(
Energy
):
def
__init__
(
self
,
position
,
energy
,
line_direction
):
self
.
energy
=
energy
self
.
line_direction
=
line_direction
super
(
LineEnergy
,
self
).
__init__
(
position
=
position
)
def
at
(
self
,
position
):
if
position
==
0
:
return
self
else
:
full_position
=
self
.
position
+
self
.
line_direction
*
position
return
self
.
__class__
(
full_position
,
self
.
energy
,
self
.
line_direction
)
@
property
def
value
(
self
):
return
self
.
energy
.
value
@
property
def
gradient
(
self
):
return
self
.
energy
.
gradient
.
dot
(
self
.
line_direction
)
@
property
def
curvature
(
self
):
return
self
.
energy
.
curvature
nifty/minimization/line_searching/line_search.py
View file @
17b1c6e9
...
...
@@ -2,6 +2,8 @@ import abc
from
keepers
import
Loggable
from
nifty
import
LineEnergy
class
LineSearch
(
object
,
Loggable
):
"""
...
...
@@ -26,23 +28,11 @@ class LineSearch(object, Loggable):
derivation.
"""
self
.
xk
=
None
self
.
pk
=
None
self
.
f_k
=
None
self
.
line_energy
=
None
self
.
f_k_minus_1
=
None
self
.
fprime_k
=
None
def
set_functions
(
self
,
f
,
fprime
,
f_args
=
()):
assert
(
callable
(
f
))
assert
(
callable
(
fprime
))
self
.
f
=
f
self
.
fprime
=
fprime
self
.
f_args
=
f_args
def
_set_coordinates
(
self
,
xk
,
pk
,
f_k
=
None
,
fprime_k
=
None
,
f_k_minus_1
=
None
):
def
_set_line_energy
(
self
,
energy
,
pk
,
f_k_minus_1
=
None
):
"""
Set the coordinates for a new line search.
...
...
@@ -61,39 +51,13 @@ class LineSearch(object, Loggable):
Function value fprime(xk).
"""
self
.
xk
=
xk
.
copy
()
self
.
pk
=
pk
.
copy
()
if
f_k
is
None
:
self
.
f_k
=
self
.
f
(
xk
)
else
:
self
.
f_k
=
f_k
if
fprime_k
is
None
:
self
.
fprime_k
=
self
.
fprime
(
xk
)
else
:
self
.
fprime_k
=
fprime_k
self
.
line_energy
=
LineEnergy
(
position
=
0.
,
energy
=
energy
,
line_direction
=
pk
)
if
f_k_minus_1
is
not
None
:
f_k_minus_1
=
f_k_minus_1
.
copy
()
self
.
f_k_minus_1
=
f_k_minus_1
def
_phi
(
self
,
alpha
):
if
alpha
==
0
:
value
=
self
.
f_k
else
:
value
=
self
.
f
(
self
.
xk
+
self
.
pk
*
alpha
,
*
self
.
f_args
)
return
value
def
_phiprime
(
self
,
alpha
):
if
alpha
==
0
:
gradient
=
self
.
fprime_k
else
:
gradient
=
self
.
fprime
(
self
.
xk
+
self
.
pk
*
alpha
,
*
self
.
f_args
)
return
gradient
.
dot
(
self
.
pk
)
@
abc
.
abstractmethod
def
perform_line_search
(
self
,
xk
,
pk
,
f_k
=
None
,
fprime_k
=
None
,
f_k_minus_1
=
None
):
...
...
nifty/minimization/line_searching/line_search_strong_wolfe.py
View file @
17b1c6e9
...
...
@@ -45,13 +45,8 @@ class LineSearchStrongWolfe(LineSearch):
self
.
max_zoom_iterations
=
int
(
max_zoom_iterations
)
self
.
_last_alpha_star
=
1.
def
perform_line_search
(
self
,
xk
,
pk
,
f_k
=
None
,
fprime_k
=
None
,
f_k_minus_1
=
None
):
self
.
_set_coordinates
(
xk
=
xk
,
pk
=
pk
,
f_k
=
f_k
,
fprime_k
=
fprime_k
,
f_k_minus_1
=
f_k_minus_1
)
def
perform_line_search
(
self
,
energy
,
pk
,
f_k_minus_1
=
None
):
self
.
_set_line_energy
(
energy
,
pk
,
f_k_minus_1
=
f_k_minus_1
)
c1
=
self
.
c1
c2
=
self
.
c2
max_step_size
=
self
.
max_step_size
...
...
@@ -59,8 +54,8 @@ class LineSearchStrongWolfe(LineSearch):
# initialize the zero phis
old_phi_0
=
self
.
f_k_minus_1
phi_0
=
self
.
_phi
(
0.
)
phiprime_0
=
self
.
_phiprime
(
0.
)
phi_0
=
self
.
line_energy
.
at
(
0
).
value
phiprime_0
=
self
.
line_energy
.
at
(
0
).
gradient
if
phiprime_0
==
0
:
self
.
logger
.
warn
(
"Flat gradient in search direction."
)
...
...
@@ -81,7 +76,8 @@ class LineSearchStrongWolfe(LineSearch):
# start the minimization loop
for
i
in
xrange
(
max_iterations
):
phi_alpha1
=
self
.
_phi
(
alpha1
)
energy_alpha1
=
self
.
line_energy
.
at
(
alpha1
)
phi_alpha1
=
energy_alpha1
.
value
if
alpha1
==
0
:
self
.
logger
.
warn
(
"Increment size became 0."
)
alpha_star
=
0.
...
...
@@ -98,7 +94,7 @@ class LineSearchStrongWolfe(LineSearch):
c1
,
c2
)
break
phiprime_alpha1
=
self
.
_phiprime
(
alpha1
)
phiprime_alpha1
=
energy_alpha1
.
gradient
if
abs
(
phiprime_alpha1
)
<=
-
c2
*
phiprime_0
:
alpha_star
=
alpha1
phi_star
=
phi_alpha1
...
...
@@ -165,7 +161,8 @@ class LineSearchStrongWolfe(LineSearch):
alpha_j
=
alpha_lo
+
0.5
*
delta_alpha
# Check if the current value of alpha_j is already sufficient
phi_alphaj
=
self
.
_phi
(
alpha_j
)
energy_alphaj
=
self
.
line_energy
.
at
(
alpha_j
)
phi_alphaj
=
energy_alphaj
.
value
# If the first Wolfe condition is not met replace alpha_hi
# by alpha_j
...
...
@@ -174,7 +171,7 @@ class LineSearchStrongWolfe(LineSearch):
alpha_recent
,
phi_recent
=
alpha_hi
,
phi_hi
alpha_hi
,
phi_hi
=
alpha_j
,
phi_alphaj
else
:
phiprime_alphaj
=
self
.
_phiprime
(
alpha_j
)
phiprime_alphaj
=
energy_alphaj
.
gradient
# If the second Wolfe condition is met, return the result
if
abs
(
phiprime_alphaj
)
<=
-
c2
*
phiprime_0
:
alpha_star
=
alpha_j
...
...
nifty/minimization/quasi_newton_minimizer.py
View file @
17b1c6e9
...
...
@@ -26,7 +26,7 @@ class QuasiNewtonMinimizer(object, Loggable):
self
.
line_searcher
=
line_searcher
self
.
callback
=
callback
def
__call__
(
self
,
x0
,
f
,
fprime
,
f_args
=
()
):
def
__call__
(
self
,
energy
):
"""
Runs the steepest descent minimization.
...
...
@@ -56,49 +56,45 @@ class QuasiNewtonMinimizer(object, Loggable):
"""
x
=
x0
.
copy
()
self
.
line_searcher
.
set_functions
(
f
=
f
,
fprime
=
fprime
,
f_args
=
f_args
)
convergence
=
0
f_k_minus_1
=
None
f_k
=
f
(
x
)
step_length
=
0
iteration_number
=
1
while
True
:
if
self
.
callback
is
not
None
:
try
:
self
.
callback
(
x
,
f_k
,
iteration_number
)
self
.
callback
(
energy
,
iteration_number
)
except
StopIteration
:
self
.
logger
.
info
(
"Minimization was stopped by callback "
"function."
)
break
# compute the the gradient for the current
x
gradient
=
fprime
(
x
)
# compute the the gradient for the current
location
gradient
=
energy
.
gradient
gradient_norm
=
gradient
.
dot
(
gradient
)
# check if
x
is at a flat point
# check if
position
is at a flat point
if
gradient_norm
==
0
:
self
.
logger
.
info
(
"Reached perfectly flat point. Stopping."
)
convergence
=
self
.
convergence_level
+
2
break
descend_direction
=
self
.
_get_descend_direction
(
x
,
gradient
)
current_position
=
energy
.
position
descend_direction
=
self
.
_get_descend_direction
(
current_position
,
gradient
)
# compute the step length, which minimizes
f_k
along the
# search direction
= the gradient
step_length
,
new_f_k
=
self
.
line_searcher
.
perform_line_search
(
xk
=
x
,
# compute the step length, which minimizes
energy.value
along the
# search direction
step_length
,
step_length
=
self
.
line_searcher
.
perform_line_search
(
energy
=
energy
,
pk
=
descend_direction
,
f_k
=
f_k
,
fprime_k
=
gradient
,
f_k_minus_1
=
f_k_minus_1
)
f_k_minus_1
=
f_k
f_k
=
new_f_k
new_position
=
current_position
+
step_length
*
descend_direction
new_energy
=
energy
.
at
(
new_position
)
# update x
x
+=
descend_direction
*
step_length
f_k_minus_1
=
energy
.
value
energy
=
new_energy
# check convergence
delta
=
abs
(
gradient
).
max
()
*
(
step_length
/
gradient_norm
)
...
...
@@ -127,7 +123,7 @@ class QuasiNewtonMinimizer(object, Loggable):
iteration_number
+=
1
return
x
,
convergence
return
energy
,
convergence
@
abc
.
abstractmethod
def
_get_descend_direction
(
self
,
gradient
,
gradient_norm
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment