Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
a2916703
Commit
a2916703
authored
Aug 20, 2018
by
Martin Reinecke
Browse files
Introduce KL_Energy (which might be parallelized in the future)
parent
747e2082
Changes
4
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_3.py
View file @
a2916703
...
...
@@ -91,26 +91,21 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples
=
20
for
i
in
range
(
2
):
metric
=
H
(
ift
.
Linearization
.
make_var
(
position
)).
metric
samples
=
[
metric
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
N_samples
)]
KL
=
ift
.
SampledKullbachLeiblerDivergence
(
H
,
samples
)
KL
=
ift
.
EnergyAdapter
(
position
,
KL
)
KL
=
ift
.
KL_Energy
(
position
,
H
,
N_samples
)
KL
,
convergence
=
minimizer
(
KL
)
position
=
KL
.
position
ift
.
plot
(
signal
(
position
),
title
=
"reconstruction"
)
ift
.
plot
([
A
(
position
),
A
(
MOCK_POSITION
)],
title
=
"power"
)
ift
.
plot
(
signal
(
KL
.
position
),
title
=
"reconstruction"
)
ift
.
plot
([
A
(
KL
.
position
),
A
(
MOCK_POSITION
)],
title
=
"power"
)
ift
.
plot_finish
(
nx
=
2
,
xsize
=
12
,
ysize
=
6
,
title
=
"loop"
,
name
=
"loop.png"
)
sc
=
ift
.
StatCalculator
()
for
sample
in
samples
:
sc
.
add
(
signal
(
sample
+
position
))
for
sample
in
KL
.
samples
:
sc
.
add
(
signal
(
sample
+
KL
.
position
))
ift
.
plot
(
sc
.
mean
,
title
=
"mean"
)
ift
.
plot
(
ift
.
sqrt
(
sc
.
var
),
title
=
"std deviation"
)
powers
=
[
A
(
s
+
position
)
for
s
in
samples
]
ift
.
plot
([
A
(
position
),
A
(
MOCK_POSITION
)]
+
powers
,
title
=
"power"
)
powers
=
[
A
(
s
+
KL
.
position
)
for
s
in
KL
.
samples
]
ift
.
plot
([
A
(
KL
.
position
),
A
(
MOCK_POSITION
)]
+
powers
,
title
=
"power"
)
ift
.
plot_finish
(
nx
=
3
,
xsize
=
16
,
ysize
=
5
,
title
=
"results"
,
name
=
"results.png"
)
nifty5/__init__.py
View file @
a2916703
...
...
@@ -66,6 +66,7 @@ from .minimization.energy import Energy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.line_energy
import
LineEnergy
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.kl_energy
import
KL_Energy
from
.sugar
import
*
from
.plotting.plot
import
plot
,
plot_finish
...
...
nifty5/minimization/energy_adapter.py
View file @
a2916703
...
...
@@ -26,18 +26,6 @@ class EnergyAdapter(Energy):
def
at
(
self
,
position
):
return
EnergyAdapter
(
position
,
self
.
_op
,
self
.
_constants
)
def
_fill_all
(
self
):
if
len
(
self
.
_constants
)
==
0
:
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
))
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
self
.
_constants
else
1.
,
dom
)
for
key
,
dom
in
self
.
_position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
self
.
_position
.
domain
,
tuple
(
ops
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
_metric
@
property
def
value
(
self
):
return
self
.
_val
...
...
nifty5/minimization/kl_energy.py
0 → 100644
View file @
a2916703
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
.energy
import
Energy
from
..linearization
import
Linearization
from
..operators.scaling_operator
import
ScalingOperator
from
..operators.block_diagonal_operator
import
BlockDiagonalOperator
from
..
import
utilities
class
KL_Energy
(
Energy
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
):
super
(
KL_Energy
,
self
).
__init__
(
position
)
self
.
_h
=
h
self
.
_constants
=
constants
if
_samples
is
None
:
met
=
h
(
Linearization
.
make_var
(
position
)).
metric
_samples
=
tuple
(
met
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
nsamp
))
self
.
_samples
=
_samples
if
len
(
constants
)
==
0
:
tmp
=
Linearization
.
make_var
(
position
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
position
.
domain
,
tuple
(
ops
))
tmp
=
Linearization
(
position
,
bdop
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
tmp
+
v
),
self
.
_samples
)
tmp
=
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_samples
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
metric
def
at
(
self
,
position
):
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
)
@
property
def
value
(
self
):
return
self
.
_val
@
property
def
gradient
(
self
):
return
self
.
_grad
def
apply_metric
(
self
,
x
):
return
self
.
_metric
(
x
)
@
property
def
samples
(
self
):
return
self
.
_samples
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment