Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
09b16aff
Commit
09b16aff
authored
Aug 29, 2018
by
Martin Reinecke
Browse files
Merge branch 'cleanup' into 'NIFTy_5'
Cleanup See merge request ift/nifty-dev!98
parents
8af5299c
45dcfca7
Changes
21
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_3.py
View file @
09b16aff
...
...
@@ -72,8 +72,8 @@ if __name__ == '__main__':
# set up minimization and inversion schemes
ic_sampling
=
ift
.
GradientNormController
(
iteration_limit
=
100
)
ic_newton
=
ift
.
DeltaEnergy
Controller
(
name
=
'Newton'
,
tol
_rel_deltaE
=
1e-
8
,
iteration_limit
=
100
)
ic_newton
=
ift
.
GradInfNorm
Controller
(
name
=
'Newton'
,
tol
=
1e-
7
,
iteration_limit
=
100
0
)
minimizer
=
ift
.
NewtonCG
(
ic_newton
)
# build model Hamiltonian
...
...
@@ -91,7 +91,7 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples
=
20
for
i
in
range
(
2
):
KL
=
ift
.
KL_Energy
(
position
,
H
,
N_samples
,
want_metric
=
True
)
KL
=
ift
.
KL_Energy
(
position
,
H
,
N_samples
)
KL
,
convergence
=
minimizer
(
KL
)
position
=
KL
.
position
...
...
nifty5/__init__.py
View file @
09b16aff
...
...
@@ -54,9 +54,9 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator
from
.minimization.line_search
import
LineSearch
from
.minimization.line_search_strong_wolfe
import
LineSearchStrongWolfe
from
.minimization.iteration_controllers
import
(
IterationController
,
GradientNormController
,
DeltaEnergyController
)
IterationController
,
GradientNormController
,
DeltaEnergyController
,
GradInfNormController
)
from
.minimization.minimizer
import
Minimizer
from
.minimization.conjugate_gradient
import
ConjugateGradient
from
.minimization.nonlinear_cg
import
NonlinearCG
...
...
@@ -66,12 +66,11 @@ from .minimization.descent_minimizers import (
from
.minimization.scipy_minimizer
import
(
ScipyMinimizer
,
L_BFGS_B
,
ScipyCG
)
from
.minimization.energy
import
Energy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.line_energy
import
LineEnergy
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.kl_energy
import
KL_Energy
from
.sugar
import
*
from
.plotting
.plot
import
Plot
from
.plot
import
Plot
from
.library.amplitude_model
import
AmplitudeModel
from
.library.inverse_gamma_model
import
InverseGammaModel
...
...
nifty5/data_objects/distributed_do.py
View file @
09b16aff
...
...
@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"local_data"
,
"ibegin"
,
"ibegin_from_shape"
,
"np_allreduce_sum"
,
"np_allreduce_min"
,
"np_allreduce_max"
,
"distaxis"
,
"from_local_data"
,
"from_global_data"
,
"to_global_data"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"transpose"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
]
...
...
@@ -553,3 +553,22 @@ def ensure_default_distributed(arr):
if
arr
.
_distaxis
!=
0
:
arr
=
redistribute
(
arr
,
dist
=
0
)
return
arr
def
absmax
(
arr
):
if
arr
.
_data
.
size
==
0
:
tmp
=
np
.
array
(
0
,
dtype
=
arr
.
_data
.
dtype
)
else
:
tmp
=
np
.
asarray
(
np
.
linalg
.
norm
(
arr
.
_data
,
ord
=
np
.
inf
))
res
=
np
.
empty_like
(
tmp
)
_comm
.
Allreduce
(
tmp
,
res
,
MPI
.
MAX
)
return
res
[()]
def
norm
(
arr
,
ord
=
2
):
if
ord
==
np
.
inf
:
return
absmax
(
arr
)
tmp
=
np
.
asarray
(
np
.
linalg
.
norm
(
np
.
atleast_1d
(
arr
.
_data
),
ord
=
ord
)
**
ord
)
res
=
np
.
empty_like
(
tmp
)
_comm
.
Allreduce
(
tmp
,
res
,
MPI
.
SUM
)
return
res
[()]
**
(
1.
/
ord
)
nifty5/data_objects/numpy_do.py
View file @
09b16aff
...
...
@@ -31,7 +31,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"local_data"
,
"ibegin"
,
"ibegin_from_shape"
,
"np_allreduce_sum"
,
"np_allreduce_min"
,
"np_allreduce_max"
,
"distaxis"
,
"from_local_data"
,
"from_global_data"
,
"to_global_data"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
]
...
...
@@ -141,3 +141,11 @@ def ensure_not_distributed(arr, axes):
def
ensure_default_distributed
(
arr
):
return
arr
def
absmax
(
arr
):
return
np
.
linalg
.
norm
(
arr
,
ord
=
np
.
inf
)
def
norm
(
arr
,
ord
=
2
):
return
np
.
linalg
.
norm
(
np
.
atleast_1d
(
arr
),
ord
=
ord
)
nifty5/extra
/energy_and_model_tests
.py
→
nifty5/extra.py
View file @
09b16aff
...
...
@@ -20,14 +20,61 @@ from __future__ import absolute_import, division, print_function
import
numpy
as
np
from
..compat
import
*
from
..linearization
import
Linearization
from
..sugar
import
from_random
from
.compat
import
*
from
.field
import
Field
from
.linearization
import
Linearization
from
.sugar
import
from_random
__all__
=
[
"check_value_gradient_consistency"
,
__all__
=
[
"consistency_check"
,
"check_value_gradient_consistency"
,
"check_value_gradient_metric_consistency"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
_adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
))
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
def
_inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
_full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
_adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
_inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
def
consistency_check
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
_full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
_full_implementation
(
op
.
adjoint
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
_full_implementation
(
op
.
inverse
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
_full_implementation
(
op
.
adjoint
.
inverse
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
def
_get_acceptable_location
(
op
,
loc
,
lin
):
if
not
np
.
isfinite
(
lin
.
val
.
sum
()):
raise
ValueError
(
'Initial value must be finite'
)
...
...
nifty5/extra/__init__.py
deleted
100644 → 0
View file @
8af5299c
from
.operator_tests
import
consistency_check
from
.energy_and_model_tests
import
*
nifty5/extra/operator_tests.py
deleted
100644 → 0
View file @
8af5299c
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
absolute_import
,
division
,
print_function
import
numpy
as
np
from
..compat
import
*
from
..field
import
Field
from
..sugar
import
from_random
__all__
=
[
"consistency_check"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
))
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
def
inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
def
consistency_check
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
adjoint
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
inverse
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
adjoint
.
inverse
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
nifty5/field.py
View file @
09b16aff
...
...
@@ -360,25 +360,20 @@ class Field(object):
# For the moment, do this the explicit, non-optimized way
return
(
self
.
conjugate
()
*
x
).
sum
(
spaces
=
spaces
)
def
norm
(
self
):
def
norm
(
self
,
ord
=
2
):
""" Computes the L2-norm of the field values.
Returns
-------
float
The L2-norm of the field values.
"""
return
np
.
sqrt
(
abs
(
self
.
vdot
(
x
=
self
)))
def
squared_norm
(
self
):
""" Computes the square of the L2-norm of the field values.
Parameters
----------
ord : int, default=2
accepted values: 1, 2, ..., np.inf
Returns
-------
float
The
square of the
L2-norm of the field values.
The L2-norm of the field values.
"""
return
abs
(
self
.
vdot
(
x
=
self
)
)
return
dobj
.
norm
(
self
.
_val
,
ord
)
def
conjugate
(
self
):
""" Returns the complex conjugate of the field.
...
...
nifty5/linearization.py
View file @
09b16aff
...
...
@@ -152,6 +152,9 @@ class Linearization(object):
def
add_metric
(
self
,
metric
):
return
self
.
new
(
self
.
_val
,
self
.
_jac
,
metric
)
def
with_want_metric
(
self
):
return
Linearization
(
self
.
_val
,
self
.
_jac
,
self
.
_metric
,
True
)
@
staticmethod
def
make_var
(
field
,
want_metric
=
False
):
from
.operators.scaling_operator
import
ScalingOperator
...
...
@@ -163,3 +166,15 @@ class Linearization(object):
from
.operators.simple_linear_operators
import
NullOperator
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
),
want_metric
=
want_metric
)
@
staticmethod
def
make_partial_var
(
field
,
constants
,
want_metric
=
False
):
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.simple_linear_operators
import
NullOperator
if
len
(
constants
)
==
0
:
return
Linearization
.
make_var
(
field
,
want_metric
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
field
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
fielld
.
domain
,
tuple
(
ops
))
return
Linearization
(
field
,
bdop
,
want_metric
=
want_metric
)
nifty5/minimization/descent_minimizers.py
View file @
09b16aff
...
...
@@ -22,7 +22,7 @@ import numpy as np
from
..compat
import
*
from
..logger
import
logger
from
.line_search
_strong_wolfe
import
LineSearch
StrongWolfe
from
.line_search
import
LineSearch
from
.minimizer
import
Minimizer
...
...
@@ -40,10 +40,10 @@ class DescentMinimizer(Minimizer):
Object that decides when to terminate the minimization.
line_searcher : callable *optional*
Function which infers the step size in the descent direction
(default : LineSearch
StrongWolfe
()).
(default : LineSearch()).
"""
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
StrongWolfe
()):
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
()):
self
.
_controller
=
controller
self
.
line_searcher
=
line_searcher
...
...
@@ -144,8 +144,7 @@ class RelaxedNewton(DescentMinimizer):
def
__init__
(
self
,
controller
,
line_searcher
=
None
):
if
line_searcher
is
None
:
line_searcher
=
LineSearchStrongWolfe
(
preferred_initial_step_size
=
1.
)
line_searcher
=
LineSearch
(
preferred_initial_step_size
=
1.
)
super
(
RelaxedNewton
,
self
).
__init__
(
controller
=
controller
,
line_searcher
=
line_searcher
)
...
...
@@ -161,8 +160,7 @@ class NewtonCG(DescentMinimizer):
def
__init__
(
self
,
controller
,
line_searcher
=
None
):
if
line_searcher
is
None
:
line_searcher
=
LineSearchStrongWolfe
(
preferred_initial_step_size
=
1.
)
line_searcher
=
LineSearch
(
preferred_initial_step_size
=
1.
)
super
(
NewtonCG
,
self
).
__init__
(
controller
=
controller
,
line_searcher
=
line_searcher
)
...
...
@@ -201,7 +199,7 @@ class NewtonCG(DescentMinimizer):
class
L_BFGS
(
DescentMinimizer
):
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
StrongWolfe
(),
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
(),
max_history_length
=
5
):
super
(
L_BFGS
,
self
).
__init__
(
controller
=
controller
,
line_searcher
=
line_searcher
)
...
...
@@ -266,7 +264,7 @@ class VL_BFGS(DescentMinimizer):
Microsoft
"""
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
StrongWolfe
(),
def
__init__
(
self
,
controller
,
line_searcher
=
LineSearch
(),
max_history_length
=
5
):
super
(
VL_BFGS
,
self
).
__init__
(
controller
=
controller
,
line_searcher
=
line_searcher
)
...
...
nifty5/minimization/energy_adapter.py
View file @
09b16aff
...
...
@@ -13,14 +13,8 @@ class EnergyAdapter(Energy):
self
.
_op
=
op
self
.
_constants
=
constants
self
.
_want_metric
=
want_metric
if
len
(
self
.
_constants
)
==
0
:
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
,
want_metric
))
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
self
.
_constants
else
1.
,
dom
)
for
key
,
dom
in
self
.
_position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
self
.
_position
.
domain
,
tuple
(
ops
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
,
want_metric
=
want_metric
))
lin
=
Linearization
.
make_partial_var
(
position
,
constants
,
want_metric
)
tmp
=
self
.
_op
(
lin
)
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
_metric
...
...
nifty5/minimization/iteration_controllers.py
View file @
09b16aff
...
...
@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
from
..compat
import
*
from
..logger
import
logger
from
..utilities
import
NiftyMetaBase
import
numpy
as
np
class
IterationController
(
NiftyMetaBase
()):
...
...
@@ -145,6 +146,48 @@ class GradientNormController(IterationController):
return
self
.
CONTINUE
class
GradInfNormController
(
IterationController
):
def
__init__
(
self
,
tol
=
None
,
convergence_level
=
1
,
iteration_limit
=
None
,
name
=
None
):
self
.
_tol
=
tol
self
.
_convergence_level
=
convergence_level
self
.
_iteration_limit
=
iteration_limit
self
.
_name
=
name
def
start
(
self
,
energy
):
self
.
_itcount
=
-
1
self
.
_ccount
=
0
return
self
.
check
(
energy
)
def
check
(
self
,
energy
):
self
.
_itcount
+=
1
crit
=
energy
.
gradient
.
norm
(
np
.
inf
)
/
abs
(
energy
.
value
)
if
self
.
_tol
is
not
None
and
crit
<=
self
.
_tol
:
self
.
_ccount
+=
1
else
:
self
.
_ccount
=
max
(
0
,
self
.
_ccount
-
1
)
# report
if
self
.
_name
is
not
None
:
logger
.
info
(
"{}: Iteration #{} energy={:.6E} crit={:.2E} clvl={}"
.
format
(
self
.
_name
,
self
.
_itcount
,
energy
.
value
,
crit
,
self
.
_ccount
))
# Are we done?
if
self
.
_iteration_limit
is
not
None
:
if
self
.
_itcount
>=
self
.
_iteration_limit
:
logger
.
warning
(
"{} Iteration limit reached. Assuming convergence"
.
format
(
""
if
self
.
_name
is
None
else
self
.
_name
+
": "
))
return
self
.
CONVERGED
if
self
.
_ccount
>=
self
.
_convergence_level
:
return
self
.
CONVERGED
return
self
.
CONTINUE
class
DeltaEnergyController
(
IterationController
):
def
__init__
(
self
,
tol_rel_deltaE
,
convergence_level
=
1
,
iteration_limit
=
None
,
name
=
None
):
...
...
nifty5/minimization/kl_energy.py
View file @
09b16aff
...
...
@@ -9,33 +9,32 @@ from .. import utilities
class
KL_Energy
(
Energy
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
,
want_metric
=
False
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
):
super
(
KL_Energy
,
self
).
__init__
(
position
)
self
.
_h
=
h
self
.
_constants
=
constants
self
.
_want_metric
=
want_metric
if
_samples
is
None
:
met
=
h
(
Linearization
.
make_var
(
position
,
True
)).
metric
_samples
=
tuple
(
met
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
nsamp
))
self
.
_samples
=
_samples
if
len
(
constants
)
==
0
:
tmp
=
Linearization
.
make_var
(
position
,
want_metric
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
position
.
domain
,
tuple
(
ops
))
tmp
=
Linearization
(
position
,
bdop
,
want_metric
=
want_metric
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
tmp
+
v
),
self
.
_samples
)
tmp
=
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_samples
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
metric
self
.
_lin
=
Linearization
.
make_partial_var
(
position
,
constants
)
v
,
g
=
None
,
None
for
s
in
self
.
_samples
:
tmp
=
self
.
_h
(
self
.
_lin
+
s
)
if
v
is
None
:
v
=
tmp
.
val
.
local_data
[()]
g
=
tmp
.
gradient
else
:
v
+=
tmp
.
val
.
local_data
[()]
g
=
g
+
tmp
.
gradient
self
.
_val
=
v
/
len
(
self
.
_samples
)
self
.
_grad
=
g
*
(
1.
/
len
(
self
.
_samples
))
self
.
_metric
=
None
def
at
(
self
,
position
):
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
,
self
.
_want_metric
)
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
)
@
property
def
value
(
self
):
...
...
@@ -45,11 +44,20 @@ class KL_Energy(Energy):
def
gradient
(
self
):
return
self
.
_grad
def
_get_metric
(
self
):
if
self
.
_metric
is
None
:
lin
=
self
.
_lin
.
with_want_metric
()
mymap
=
map
(
lambda
v
:
self
.
_h
(
lin
+
v
).
metric
,
self
.
_samples
)
self
.
_metric
=
utilities
.
my_sum
(
mymap
)
self
.
_metric
=
self
.
_metric
.
scale
(
1.
/
len
(
self
.
_samples
))
def
apply_metric
(
self
,
x
):
self
.
_get_metric
()
return
self
.
_metric
(
x
)
@
property
def
metric
(
self
):
self
.
_get_metric
()
return
self
.
_metric
@
property
...
...
nifty5/minimization/line_energy.py
deleted
100644 → 0
View file @
8af5299c
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
class
LineEnergy
(
object
):
""" Evaluates an underlying Energy along a certain line direction.
Given an Energy class and a line direction, its position is parametrized by
a scalar step size along the descent direction relative to a zero point.
Parameters
----------
line_position : float
Defines the full spatial position of this energy via
self.energy.position = zero_point + line_position*line_direction
energy : Energy
The Energy object which will be evaluated along the given direction.
line_direction : Field
Direction used for line evaluation. Does not have to be normalized.
offset : float *optional*
Indirectly defines the zero point of the line via the equation
energy.position = zero_point + offset*line_direction
(default : 0.).
Notes
-----
The LineEnergy is used in minimization schemes in order perform line
searches. It describes an underlying Energy which is restricted along one
direction, only requiring the step size parameter to determine a new
position.
"""
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
self
.
_line_position
=
float
(
line_position
)
self
.
_line_direction
=
line_direction
if
self
.
_line_position
==
float
(
offset
):
self
.
_energy
=
energy
else
:
pos
=
energy
.
position
\
+
(
self
.
_line_position
-
float
(
offset
))
*
self
.
_line_direction
self
.
_energy
=
energy
.
at
(
position
=
pos
)
def
at
(
self
,
line_position
):
""" Returns LineEnergy at new position, memorizing the zero point.
Parameters
----------
line_position : float
Parameter for the new position on the line direction.
Returns
-------
LineEnergy object at new position with same zero point as `self`.
"""
return
LineEnergy
(
line_position
,
self
.
_energy
,
self
.
_line_direction
,
offset
=
self
.
_line_position
)
@
property
def
energy
(
self
):
"""
Energy : The underlying Energy object
"""
return
self
.
_energy
@
property
def
value
(
self
):
"""