ift / NIFTy · Commits · 81663bbf

Commit 81663bbf, authored Aug 29, 2018 by Martin Reinecke

Merge remote-tracking branch 'origin/NIFTy_5' into adjust_variances_but_right

Parents: 4fb3544c, 09b16aff
Changes: 21
demos/getting_started_3.py

@@ -72,8 +72,8 @@ if __name__ == '__main__':
     # set up minimization and inversion schemes
     ic_sampling = ift.GradientNormController(iteration_limit=100)
-    ic_newton = ift.DeltaEnergyController(
-        name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
+    ic_newton = ift.GradInfNormController(
+        name='Newton', tol=1e-7, iteration_limit=1000)
     minimizer = ift.NewtonCG(ic_newton)

     # build model Hamiltonian

@@ -91,7 +91,7 @@ if __name__ == '__main__':
     # number of samples used to estimate the KL
     N_samples = 20
     for i in range(2):
-        KL = ift.KL_Energy(position, H, N_samples, want_metric=True)
+        KL = ift.KL_Energy(position, H, N_samples)
         KL, convergence = minimizer(KL)
         position = KL.position
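The effect of this hunk is a change of stopping criterion, not of minimizer. A minimal sketch contrasting the two controllers, assuming the post-merge NIFTy_5 API (both classes remain exported, see the nifty5/__init__.py hunk below):

    import nifty5 as ift

    # old criterion: relative change of the energy value between iterations
    ic_old = ift.DeltaEnergyController(
        name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)

    # new criterion: infinity norm of the gradient relative to |energy|
    ic_new = ift.GradInfNormController(
        name='Newton', tol=1e-7, iteration_limit=1000)

    minimizer = ift.NewtonCG(ic_new)  # the minimizer itself is unchanged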
nifty5/__init__.py

@@ -54,9 +54,9 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
     StatCalculator

 from .minimization.line_search import LineSearch
-from .minimization.line_search_strong_wolfe import LineSearchStrongWolfe
 from .minimization.iteration_controllers import (
-    IterationController, GradientNormController, DeltaEnergyController)
+    IterationController, GradientNormController, DeltaEnergyController,
+    GradInfNormController)
 from .minimization.minimizer import Minimizer
 from .minimization.conjugate_gradient import ConjugateGradient
 from .minimization.nonlinear_cg import NonlinearCG

@@ -66,12 +66,11 @@ from .minimization.descent_minimizers import (
 from .minimization.scipy_minimizer import (ScipyMinimizer, L_BFGS_B, ScipyCG)
 from .minimization.energy import Energy
 from .minimization.quadratic_energy import QuadraticEnergy
-from .minimization.line_energy import LineEnergy
 from .minimization.energy_adapter import EnergyAdapter
 from .minimization.kl_energy import KL_Energy

 from .sugar import *
-from .plotting.plot import Plot
+from .plot import Plot

 from .library.amplitude_model import AmplitudeModel
 from .library.inverse_gamma_model import InverseGammaModel
nifty5/data_objects/distributed_do.py

@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
            "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
            "np_allreduce_min", "np_allreduce_max",
            "distaxis", "from_local_data", "from_global_data", "to_global_data",
-           "redistribute", "default_distaxis", "is_numpy",
+           "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
            "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
            "ensure_not_distributed", "ensure_default_distributed"]

@@ -553,3 +553,22 @@ def ensure_default_distributed(arr):
     if arr._distaxis != 0:
         arr = redistribute(arr, dist=0)
     return arr
+
+
+def absmax(arr):
+    if arr._data.size == 0:
+        tmp = np.array(0, dtype=arr._data.dtype)
+    else:
+        tmp = np.asarray(np.linalg.norm(arr._data, ord=np.inf))
+    res = np.empty_like(tmp)
+    _comm.Allreduce(tmp, res, MPI.MAX)
+    return res[()]
+
+
+def norm(arr, ord=2):
+    if ord == np.inf:
+        return absmax(arr)
+    tmp = np.asarray(np.linalg.norm(np.atleast_1d(arr._data), ord=ord)**ord)
+    res = np.empty_like(tmp)
+    _comm.Allreduce(tmp, res, MPI.SUM)
+    return res[()]**(1./ord)
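The new `norm` never gathers the distributed array: each MPI task raises its local norm to the `ord`-th power, the powers are combined with `Allreduce(SUM)`, and the root is taken at the end; for `ord=np.inf` the reduction is a maximum instead. A pure-NumPy sketch of the identity this relies on (no MPI required):

    import numpy as np

    chunks = [np.random.randn(5) for _ in range(3)]  # stand-ins for per-task data
    full = np.concatenate(chunks)

    p = 2  # any finite ord
    s = sum(np.linalg.norm(c, ord=p)**p for c in chunks)  # plays the role of Allreduce(SUM)
    assert np.allclose(s**(1./p), np.linalg.norm(full, ord=p))

    # for ord=np.inf the global norm is the maximum of the local norms
    assert np.allclose(max(np.linalg.norm(c, ord=np.inf) for c in chunks),
                       np.linalg.norm(full, ord=np.inf))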
nifty5/data_objects/numpy_do.py

@@ -31,7 +31,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
            "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
            "np_allreduce_min", "np_allreduce_max",
            "distaxis", "from_local_data", "from_global_data", "to_global_data",
-           "redistribute", "default_distaxis", "is_numpy",
+           "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
            "lock", "locked", "uniform_full", "to_global_data_rw",
            "ensure_not_distributed", "ensure_default_distributed"]

@@ -141,3 +141,11 @@ def ensure_not_distributed(arr, axes):

 def ensure_default_distributed(arr):
     return arr
+
+
+def absmax(arr):
+    return np.linalg.norm(arr, ord=np.inf)
+
+
+def norm(arr, ord=2):
+    return np.linalg.norm(np.atleast_1d(arr), ord=ord)
nifty5/extra/energy_and_model_tests.py → nifty5/extra.py

@@ -20,14 +20,61 @@ from __future__ import absolute_import, division, print_function

 import numpy as np

-from ..compat import *
-from ..linearization import Linearization
-from ..sugar import from_random
+from .compat import *
+from .field import Field
+from .linearization import Linearization
+from .sugar import from_random

-__all__ = ["check_value_gradient_consistency",
+__all__ = ["consistency_check", "check_value_gradient_consistency",
            "check_value_gradient_metric_consistency"]
+
+
+def _assert_allclose(f1, f2, atol, rtol):
+    if isinstance(f1, Field):
+        return np.testing.assert_allclose(f1.local_data, f2.local_data,
+                                          atol=atol, rtol=rtol)
+    for key, val in f1.items():
+        _assert_allclose(val, f2[key], atol=atol, rtol=rtol)
+
+
+def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
+    needed_cap = op.TIMES | op.ADJOINT_TIMES
+    if (op.capability & needed_cap) != needed_cap:
+        return
+    f1 = from_random("normal", op.domain, dtype=domain_dtype)
+    f2 = from_random("normal", op.target, dtype=target_dtype)
+    res1 = f1.vdot(op.adjoint_times(f2))
+    res2 = op.times(f1).vdot(f2)
+    np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
+
+
+def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
+    needed_cap = op.TIMES | op.INVERSE_TIMES
+    if (op.capability & needed_cap) != needed_cap:
+        return
+    foo = from_random("normal", op.target, dtype=target_dtype)
+    res = op(op.inverse_times(foo))
+    _assert_allclose(res, foo, atol=atol, rtol=rtol)
+    foo = from_random("normal", op.domain, dtype=domain_dtype)
+    res = op.inverse_times(op(foo))
+    _assert_allclose(res, foo, atol=atol, rtol=rtol)
+
+
+def _full_implementation(op, domain_dtype, target_dtype, atol, rtol):
+    _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
+    _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
+
+
+def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
+                      atol=0, rtol=1e-7):
+    _full_implementation(op, domain_dtype, target_dtype, atol, rtol)
+    _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
+    _full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
+    _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype,
+                         atol, rtol)


 def _get_acceptable_location(op, loc, lin):
     if not np.isfinite(lin.val.sum()):
         raise ValueError('Initial value must be finite')
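`consistency_check` draws random fields and verifies the adjointness identity <f1, A† f2> = <A f1, f2> plus the round trips through `inverse_times`, for the operator and its adjoint/inverse combinations. A minimal usage sketch, assuming `RGSpace` and `ScalingOperator` behave as elsewhere in NIFTy_5:

    import numpy as np
    import nifty5 as ift
    from nifty5.extra import consistency_check

    dom = ift.RGSpace(16)
    op = ift.ScalingOperator(2., dom)  # self-adjoint and invertible, so all
                                       # four variants should pass
    consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64)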
nifty5/extra/__init__.py (deleted, 100644 → 0)

-from .operator_tests import consistency_check
-from .energy_and_model_tests import *
nifty5/extra/operator_tests.py (deleted, 100644 → 0; its contents move to nifty5/extra.py with underscore-prefixed helpers, see above)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import absolute_import, division, print_function

import numpy as np

from ..compat import *
from ..field import Field
from ..sugar import from_random

__all__ = ["consistency_check"]


def _assert_allclose(f1, f2, atol, rtol):
    if isinstance(f1, Field):
        return np.testing.assert_allclose(f1.local_data, f2.local_data,
                                          atol=atol, rtol=rtol)
    for key, val in f1.items():
        _assert_allclose(val, f2[key], atol=atol, rtol=rtol)


def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
    needed_cap = op.TIMES | op.ADJOINT_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    f1 = from_random("normal", op.domain, dtype=domain_dtype)
    f2 = from_random("normal", op.target, dtype=target_dtype)
    res1 = f1.vdot(op.adjoint_times(f2))
    res2 = op.times(f1).vdot(f2)
    np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)


def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
    needed_cap = op.TIMES | op.INVERSE_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    foo = from_random("normal", op.target, dtype=target_dtype)
    res = op(op.inverse_times(foo))
    _assert_allclose(res, foo, atol=atol, rtol=rtol)
    foo = from_random("normal", op.domain, dtype=domain_dtype)
    res = op.inverse_times(op(foo))
    _assert_allclose(res, foo, atol=atol, rtol=rtol)


def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
    adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
    inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)


def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
                      atol=0, rtol=1e-7):
    full_implementation(op, domain_dtype, target_dtype, atol, rtol)
    full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
    full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
    full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
                        rtol)
nifty5/field.py

@@ -360,25 +360,20 @@ class Field(object):
         # For the moment, do this the explicit, non-optimized way
         return (self.conjugate()*x).sum(spaces=spaces)

-    def norm(self):
+    def norm(self, ord=2):
         """ Computes the L2-norm of the field values.

+        Parameters
+        ----------
+        ord : int, default=2
+            accepted values: 1, 2, ..., np.inf
+
         Returns
         -------
         float
             The L2-norm of the field values.
         """
-        return np.sqrt(abs(self.vdot(x=self)))
-
-    def squared_norm(self):
-        """ Computes the square of the L2-norm of the field values.
-
-        Returns
-        -------
-        float
-            The square of the L2-norm of the field values.
-        """
-        return abs(self.vdot(x=self))
+        return dobj.norm(self._val, ord)

     def conjugate(self):
         """ Returns the complex conjugate of the field.
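With the `ord` parameter, `Field.norm` can feed the infinity-norm criterion used by `GradInfNormController` (`squared_norm` is dropped in the same hunk). A short usage sketch, assuming the sugar-level `from_random`:

    import numpy as np
    import nifty5 as ift

    f = ift.from_random('normal', ift.RGSpace(8))
    f.norm()        # L2-norm, the previous behaviour
    f.norm(1)       # L1-norm
    f.norm(np.inf)  # largest absolute entry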
nifty5/linearization.py

(Note: the added `make_partial_var` is patched up here relative to the raw commit: `fielld` is corrected to `field`, and the unused `NullOperator` import is replaced by the `BlockDiagonalOperator` import the body actually needs.)

@@ -163,6 +163,9 @@ class Linearization(object):
     def add_metric(self, metric):
         return self.new(self._val, self._jac, metric)

+    def with_want_metric(self):
+        return Linearization(self._val, self._jac, self._metric, True)
+
     @staticmethod
     def make_var(field, want_metric=False):
         from .operators.scaling_operator import ScalingOperator

@@ -174,3 +177,15 @@ class Linearization(object):
         from .operators.simple_linear_operators import NullOperator
         return Linearization(field, NullOperator(field.domain, field.domain),
                              want_metric=want_metric)
+
+    @staticmethod
+    def make_partial_var(field, constants, want_metric=False):
+        from .operators.scaling_operator import ScalingOperator
+        from .operators.block_diagonal_operator import BlockDiagonalOperator
+        if len(constants) == 0:
+            return Linearization.make_var(field, want_metric)
+        else:
+            ops = [ScalingOperator(0. if key in constants else 1., dom)
+                   for key, dom in field.domain.items()]
+            bdop = BlockDiagonalOperator(field.domain, tuple(ops))
+            return Linearization(field, bdop, want_metric=want_metric)
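`make_partial_var` builds a `Linearization` whose Jacobian is block diagonal, with a zero block for every key listed in `constants`, so gradients with respect to those components vanish. A hypothetical sketch (the `MultiDomain` layout and field names are illustrative, not from the commit):

    import nifty5 as ift
    from nifty5.linearization import Linearization

    dom = ift.MultiDomain.make({'xi': ift.RGSpace(4), 'tau': ift.RGSpace(4)})
    pos = ift.from_random('normal', dom)

    lin = Linearization.make_partial_var(pos, ['xi'])  # 'xi' block scaled by 0.
    full = Linearization.make_var(pos)                 # every block variable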
nifty5/minimization/descent_minimizers.py

@@ -22,7 +22,7 @@ import numpy as np
 from ..compat import *
 from ..logger import logger
-from .line_search_strong_wolfe import LineSearchStrongWolfe
+from .line_search import LineSearch
 from .minimizer import Minimizer

@@ -40,10 +40,10 @@ class DescentMinimizer(Minimizer):
         Object that decides when to terminate the minimization.
     line_searcher : callable *optional*
         Function which infers the step size in the descent direction
-        (default : LineSearchStrongWolfe()).
+        (default : LineSearch()).
     """

-    def __init__(self, controller, line_searcher=LineSearchStrongWolfe()):
+    def __init__(self, controller, line_searcher=LineSearch()):
         self._controller = controller
         self.line_searcher = line_searcher

@@ -144,8 +144,7 @@ class RelaxedNewton(DescentMinimizer):
     def __init__(self, controller, line_searcher=None):
         if line_searcher is None:
-            line_searcher = LineSearchStrongWolfe(
-                preferred_initial_step_size=1.)
+            line_searcher = LineSearch(preferred_initial_step_size=1.)
         super(RelaxedNewton, self).__init__(controller=controller,
                                             line_searcher=line_searcher)

@@ -161,8 +160,7 @@ class NewtonCG(DescentMinimizer):
     def __init__(self, controller, line_searcher=None):
         if line_searcher is None:
-            line_searcher = LineSearchStrongWolfe(
-                preferred_initial_step_size=1.)
+            line_searcher = LineSearch(preferred_initial_step_size=1.)
         super(NewtonCG, self).__init__(controller=controller,
                                        line_searcher=line_searcher)

@@ -201,7 +199,7 @@ class NewtonCG(DescentMinimizer):
 class L_BFGS(DescentMinimizer):
-    def __init__(self, controller, line_searcher=LineSearchStrongWolfe(),
+    def __init__(self, controller, line_searcher=LineSearch(),
                  max_history_length=5):
         super(L_BFGS, self).__init__(controller=controller,
                                      line_searcher=line_searcher)

@@ -266,7 +264,7 @@ class VL_BFGS(DescentMinimizer):
         Microsoft
     """

-    def __init__(self, controller, line_searcher=LineSearchStrongWolfe(),
+    def __init__(self, controller, line_searcher=LineSearch(),
                  max_history_length=5):
         super(VL_BFGS, self).__init__(controller=controller,
                                       line_searcher=line_searcher)
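Only the name changes in this file: the strong-Wolfe search becomes the sole line search and is now simply `LineSearch`, keeping the same constructor arguments (as the `preferred_initial_step_size` lines above show). Passing it explicitly is equivalent to the new defaults; a sketch assuming the exported names:

    import nifty5 as ift

    ic = ift.GradientNormController(iteration_limit=100)
    minimizer = ift.NewtonCG(
        ic, line_searcher=ift.LineSearch(preferred_initial_step_size=1.))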
nifty5/minimization/energy_adapter.py

@@ -13,14 +13,8 @@ class EnergyAdapter(Energy):
         self._op = op
         self._constants = constants
         self._want_metric = want_metric
-        if len(self._constants) == 0:
-            tmp = self._op(Linearization.make_var(self._position, want_metric))
-        else:
-            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
-                   for key, dom in self._position.domain.items()]
-            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
-            tmp = self._op(Linearization(self._position, bdop,
-                                         want_metric=want_metric))
+        lin = Linearization.make_partial_var(position, constants, want_metric)
+        tmp = self._op(lin)
         self._val = tmp.val.local_data[()]
         self._grad = tmp.gradient
         self._metric = tmp._metric
nifty5/minimization/iteration_controllers.py

@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
 from ..compat import *
 from ..logger import logger
 from ..utilities import NiftyMetaBase
+import numpy as np


 class IterationController(NiftyMetaBase()):

@@ -145,6 +146,48 @@ class GradientNormController(IterationController):
         return self.CONTINUE


+class GradInfNormController(IterationController):
+    def __init__(self, tol=None, convergence_level=1, iteration_limit=None,
+                 name=None):
+        self._tol = tol
+        self._convergence_level = convergence_level
+        self._iteration_limit = iteration_limit
+        self._name = name
+
+    def start(self, energy):
+        self._itcount = -1
+        self._ccount = 0
+        return self.check(energy)
+
+    def check(self, energy):
+        self._itcount += 1
+
+        crit = energy.gradient.norm(np.inf)/abs(energy.value)
+        if self._tol is not None and crit <= self._tol:
+            self._ccount += 1
+        else:
+            self._ccount = max(0, self._ccount-1)
+
+        # report
+        if self._name is not None:
+            logger.info(
+                "{}: Iteration #{} energy={:.6E} crit={:.2E} clvl={}".format(
+                    self._name, self._itcount, energy.value, crit,
+                    self._ccount))
+
+        # Are we done?
+        if self._iteration_limit is not None:
+            if self._itcount >= self._iteration_limit:
+                logger.warning(
+                    "{} Iteration limit reached. Assuming convergence".format(
+                        "" if self._name is None else self._name+": "))
+                return self.CONVERGED
+        if self._ccount >= self._convergence_level:
+            return self.CONVERGED
+
+        return self.CONTINUE
+
+
 class DeltaEnergyController(IterationController):
     def __init__(self, tol_rel_deltaE, convergence_level=1,
                  iteration_limit=None, name=None):
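The criterion computed in `check` is the infinity norm of the gradient relative to the magnitude of the energy value; `convergence_level` net passes are required, and a failed check decrements the counter again. In plain NumPy terms:

    import numpy as np

    def crit(gradient, energy_value):
        # what GradInfNormController compares against tol
        return np.linalg.norm(gradient, ord=np.inf) / abs(energy_value)

    crit(np.array([1e-6, -3e-6]), 2.5)  # 1.2e-06; counts toward convergence
                                        # once <= tol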
nifty5/minimization/kl_energy.py

@@ -9,33 +9,32 @@ from .. import utilities


 class KL_Energy(Energy):
-    def __init__(self, position, h, nsamp, constants=[], _samples=None,
-                 want_metric=False):
+    def __init__(self, position, h, nsamp, constants=[], _samples=None):
         super(KL_Energy, self).__init__(position)
         self._h = h
         self._constants = constants
-        self._want_metric = want_metric
         if _samples is None:
             met = h(Linearization.make_var(position, True)).metric
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(nsamp))
         self._samples = _samples
-        if len(constants) == 0:
-            tmp = Linearization.make_var(position, want_metric)
-        else:
-            ops = [ScalingOperator(0. if key in constants else 1., dom)
-                   for key, dom in position.domain.items()]
-            bdop = BlockDiagonalOperator(position.domain, tuple(ops))
-            tmp = Linearization(position, bdop, want_metric=want_metric)
-        mymap = map(lambda v: self._h(tmp+v), self._samples)
-        tmp = utilities.my_sum(mymap)*(1./len(self._samples))
-        self._val = tmp.val.local_data[()]
-        self._grad = tmp.gradient
-        self._metric = tmp.metric
+        self._lin = Linearization.make_partial_var(position, constants)
+        v, g = None, None
+        for s in self._samples:
+            tmp = self._h(self._lin+s)
+            if v is None:
+                v = tmp.val.local_data[()]
+                g = tmp.gradient
+            else:
+                v += tmp.val.local_data[()]
+                g = g + tmp.gradient
+        self._val = v / len(self._samples)
+        self._grad = g*(1./len(self._samples))
+        self._metric = None

     def at(self, position):
-        return KL_Energy(position, self._h, 0, self._constants, self._samples,
...
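The rewrite trades `utilities.my_sum` over a lazy `map` for an explicit loop that accumulates value and gradient one sample at a time, and no longer builds a metric (`self._metric = None`). A toy NumPy sketch of the accumulation pattern (the quadratic `h` is illustrative only, not from the commit):

    import numpy as np

    def kl_value_and_grad(h, position, samples):
        v, g = None, None
        for s in samples:                  # mirrors `for s in self._samples`
            val, grad = h(position + s)
            v = val if v is None else v + val
            g = grad if g is None else g + grad
        n = len(samples)
        return v/n, g*(1./n)               # sample-averaged value and gradient

    h = lambda x: (0.5*np.vdot(x, x).real, x)  # toy energy with gradient x
    samples = [np.random.randn(4) for _ in range(3)]
    kl_value_and_grad(h, np.zeros(4), samples)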