ift / NIFTy · Commits

Commit b5f37f09
authored Oct 24, 2016 by theos

Bugfixes related to VL-BFGS and SteepestDescent

parent 19160829
Changes: 6 files
demos/wiener_filter_hamiltonian.py  (new file, 0 → 100644)

from nifty import *

import plotly.offline as pl
import plotly.graph_objs as go

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.rank


if __name__ == "__main__":
    distribution_strategy = 'fftw'

    # signal space, its harmonic partner and the corresponding power space
    s_space = RGSpace([512, 512], dtype=np.float)
    fft = FFTOperator(s_space)
    h_space = fft.target[0]
    p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)

    # power spectrum and prior signal covariance S
    pow_spec = (lambda k: (42 / (k + 1)**3))
    S = create_power_operator(h_space, power_spectrum=pow_spec,
                              distribution_strategy=distribution_strategy)

    # draw a harmonic-space signal sh with this power spectrum; ss is its
    # position-space counterpart
    sp = Field(p_space, val=lambda z: pow_spec(z)**(1./2),
               distribution_strategy=distribution_strategy)
    sh = sp.power_synthesize(real_signal=True)
    ss = fft.inverse_times(sh)

    # response R, noise covariance N and a noise realisation n
    R = SmoothingOperator(s_space, sigma=0.01)
    # R = DiagonalOperator(s_space, diagonal=1.)
    # R._diagonal.val[200:400, 200:400] = 0

    signal_to_noise = 1
    N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
    n = Field.from_random(domain=s_space,
                          random_type='normal',
                          std=ss.std()/np.sqrt(signal_to_noise),
                          mean=0)
    # n.val.data.imag[:] = 0

    # data, information source j and Wiener filter propagator D
    d = R(ss) + n
    j = R.adjoint_times(N.inverse_times(d))
    D = PropagatorOperator(S=S, N=N, R=R)

    def energy(x):
        # Wiener filter Hamiltonian H(x) = 0.5 x^dagger D^-1 x - j^dagger x
        DIx = D.inverse_times(x)
        H = 0.5 * DIx.dot(x) - j.dot(x)
        return H.real

    def gradient(x):
        # gradient of H: D^-1 x - j, cast back to a real field
        DIx = D.inverse_times(x)
        g = DIx - j
        return_g = g.copy_empty(dtype=np.float)
        return_g.val = g.val.real
        return return_g

    def distance_measure(x, fgrad, iteration):
        # relative distance of the current iterate to the true signal
        print(iteration, ((x - ss).norm()/ss.norm()).real)

    minimizer = SteepestDescent(convergence_tolerance=0,
                                iteration_limit=50,
                                callback=distance_measure)
    # the VL_BFGS minimizer assigned last is the one actually used below
    minimizer = VL_BFGS(convergence_tolerance=0,
                        iteration_limit=50,
                        callback=distance_measure,
                        max_history_length=5)

    m0 = Field(s_space, val=1)

    (m, convergence) = minimizer(m0, energy, gradient)

    grad = gradient(m)

    # plot data, signal, information source, map and final gradient (rank 0 only)
    d_data = d.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=d_data)], filename='data.html')

    ss_data = ss.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')

    sh_data = sh.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')

    j_data = j.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=j_data)], filename='j.html')

    jabs_data = np.abs(j.val.get_full_data())
    jphase_data = np.angle(j.val.get_full_data())
    if rank == 0:
        pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
        pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')

    m_data = m.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=m_data)], filename='map.html')

    grad_data = grad.val.get_full_data().real
    if rank == 0:
        pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
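The demo above needs NIFTy, plotly and mpi4py, so as a self-contained point of reference here is a small plain-NumPy sketch of the quadratic problem it minimizes: the Wiener filter energy H(x) = 0.5 x†D⁻¹x − j†x with curvature D⁻¹ = R†N⁻¹R + S⁻¹, information source j = R†N⁻¹d and minimum m = Dj. All names in the sketch (signal_cov, noise_cov, response, ...) are invented for this illustration and are not NIFTy API.

# Toy Wiener filter problem in plain NumPy; illustration only, not NIFTy code.
import numpy as np

rng = np.random.default_rng(42)
npix = 64

# toy covariances and response: S and N diagonal, R the identity
k = np.abs(np.fft.fftfreq(npix) * npix)
signal_cov = np.diag(42.0 / (k + 1.0) ** 3)   # plays the role of S
noise_cov = np.eye(npix) * 0.1                # plays the role of N
response = np.eye(npix)                       # plays the role of R

# synthetic signal and data d = R s + n
s = rng.multivariate_normal(np.zeros(npix), signal_cov)
n = rng.multivariate_normal(np.zeros(npix), noise_cov)
d = response @ s + n

# information source j = R^T N^-1 d and curvature D^-1 = R^T N^-1 R + S^-1
j = response.T @ np.linalg.solve(noise_cov, d)
D_inv = response.T @ np.linalg.solve(noise_cov, response) + np.linalg.inv(signal_cov)

def energy(x):
    # H(x) = 0.5 x^T D^-1 x - j^T x, the same functional as in the demo
    return 0.5 * x @ (D_inv @ x) - j @ x

def gradient(x):
    # dH/dx = D^-1 x - j
    return D_inv @ x - j

# direct solution m = D j, the target any iterative minimizer should reach
m = np.linalg.solve(D_inv, j)
print("energy at optimum:", energy(m), "gradient norm:", np.linalg.norm(gradient(m)))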
nifty/minimization/bare_dot.py  (deleted, 100644 → 0)

# -*- coding: utf-8 -*-

import numpy as np


def bare_dot(a, b):
    try:
        return a.dot(b, bare=True)
    except (AttributeError, TypeError):
        pass

    try:
        return a.vdot(b)
    except (AttributeError):
        pass

    return np.vdot(a, b)
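The deleted helper first tried a Field's dot(..., bare=True), then a vdot method, and finally fell back to np.vdot; the diffs below drop its remaining callers in favour of the fields' own dot and norm methods. A plain-NumPy illustration of that final fallback, outside of NIFTy:

# For ordinary arrays, bare_dot(a, b) reduced to np.vdot(a, b).
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

print(np.vdot(a, b))            # 32.0
print(np.sqrt(np.vdot(a, a)))   # Euclidean norm of a, same as np.linalg.norm(a)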
nifty/minimization/line_searching/line_search.py

@@ -2,8 +2,6 @@ import abc
 from keepers import Loggable
 
-from ..bare_dot import bare_dot
-
 
 class LineSearch(object, Loggable):
     """
@@ -94,7 +92,7 @@ class LineSearch(object, Loggable):
         else:
             gradient = self.fprime(self.xk + self.pk*alpha, *self.f_args)
-            return bare_dot(gradient, self.pk)
+            return gradient.dot(self.pk)
 
     @abc.abstractmethod
     def perform_line_search(self, xk, pk, f_k=None, fprime_k=None,
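The changed line computes the directional derivative of the objective along the search direction, phi'(alpha) = grad f(xk + alpha*pk) . pk, now via the field's own dot method. A plain-NumPy sketch of that quantity; the quadratic f below and everything except the names xk and pk is invented for the illustration:

import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])

def f(x):
    return 0.5 * x @ (A @ x) - b @ x

def fprime(x):
    return A @ x - b

xk = np.array([0.0, 0.0])
pk = -fprime(xk)                  # a descent direction

def phi(alpha):
    return f(xk + alpha * pk)

def phiprime(alpha):
    # analogue of "gradient.dot(self.pk)" in the diff above
    return fprime(xk + alpha * pk) @ pk

print(phi(0.0), phiprime(0.0))    # phiprime(0) < 0 along a descent direction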
nifty/minimization/line_searching/line_search_strong_wolfe.py

@@ -136,10 +136,6 @@ class LineSearchStrongWolfe(LineSearch):
         cubic_delta = 0.2  # cubic
         quad_delta = 0.1  # quadratic
-        cubic_delta = 0.0  # cubic
-        quad_delta = 0.0  # quadratic
 
         # initialize the most recent versions (j-1) of phi and alpha
         alpha_recent = 0
         phi_recent = phi_0
@@ -170,6 +166,7 @@ class LineSearchStrongWolfe(LineSearch):
             # Check if the current value of alpha_j is already sufficient
             phi_alphaj = self._phi(alpha_j)
             # If the first Wolfe condition is not met replace alpha_hi
             # by alpha_j
             if (phi_alphaj > phi_0 + c1*alpha_j*phiprime_0) or \
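The if-statement in the second hunk tests the sufficient-decrease (first Wolfe) condition: alpha_j is acceptable only if phi(alpha_j) <= phi(0) + c1 * alpha_j * phi'(0). A toy check of that condition in plain Python; the names phi_0, phiprime_0, c1 and alpha_j mirror the diff, while the 1-D objective below is invented:

import numpy as np

def phi(alpha):
    return (alpha - 1.0) ** 2          # toy line-search objective

phi_0 = phi(0.0)
phiprime_0 = -2.0                      # phi'(0) of the toy objective
c1 = 1e-4

for alpha_j in (2.5, 1.0, 0.1):
    violates = phi(alpha_j) > phi_0 + c1 * alpha_j * phiprime_0
    print(alpha_j, "replace alpha_hi" if violates else "sufficient decrease")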
nifty/minimization/steepest_descent.py

@@ -5,4 +5,9 @@ from .quasi_newton_minimizer import QuasiNewtonMinimizer
 class SteepestDescent(QuasiNewtonMinimizer):
     def _get_descend_direction(self, x, gradient):
-        return gradient / (-gradient.dot(gradient))
+        descend_direction = gradient
+        norm = descend_direction.norm()
+        if norm != 1:
+            return descend_direction / -norm
+        else:
+            return descend_direction * -1
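The fix changes the returned direction from -g / (g.g) to the unit vector -g / ||g||, so the step length is left entirely to the line search instead of shrinking with the squared gradient norm. A small NumPy sketch of the two variants, outside of NIFTy:

import numpy as np

g = np.array([3.0, 4.0])                         # some gradient, ||g|| = 5

old_direction = g / (-g @ g)                     # length 1/||g|| = 0.2
norm = np.linalg.norm(g)
new_direction = g / -norm if norm != 1 else -g   # unit length, as in the fix

print(np.linalg.norm(old_direction), np.linalg.norm(new_direction))  # 0.2 1.0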
nifty/minimization/vl_bfgs.py

@@ -5,8 +5,6 @@ import numpy as np
 from .quasi_newton_minimizer import QuasiNewtonMinimizer
 from .line_searching import LineSearchStrongWolfe
-from .bare_dot import bare_dot
-
 
 class VL_BFGS(QuasiNewtonMinimizer):
     def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None,
@@ -42,7 +40,7 @@ class VL_BFGS(QuasiNewtonMinimizer):
         for i in xrange(1, len(delta)):
             descend_direction += delta[i] * b[i]
 
-        norm = np.sqrt(bare_dot(descend_direction, descend_direction))
+        norm = descend_direction.norm()
         if norm != 1:
             descend_direction /= norm
 
         return descend_direction
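Replacing np.sqrt(bare_dot(descend_direction, descend_direction)) with descend_direction.norm() relies on the two agreeing for the fields involved. A quick NumPy check of that identity for a real vector, with np.linalg.norm standing in for Field.norm():

import numpy as np

d = np.array([1.0, -2.0, 2.0])
print(np.sqrt(np.vdot(d, d)))   # 3.0, what bare_dot reduced to
print(np.linalg.norm(d))        # 3.0, the Euclidean norm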