Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
007fb8dc
Commit
007fb8dc
authored
Jul 27, 2018
by
Martin Reinecke
Browse files
adjust KL
parent
34a385d9
Changes
3
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_3b.py
View file @
007fb8dc
...
...
@@ -97,9 +97,8 @@ if __name__ == '__main__':
# build model Hamiltonian
H
=
ift
.
Hamiltonian
(
likelihood
,
ic_sampling
)
H
=
EnergyAdapter
(
MOCK_POSITION
,
H
)
INITIAL_POSITION
=
ift
.
from_random
(
'normal'
,
H
.
position
.
domain
)
INITIAL_POSITION
=
ift
.
from_random
(
'normal'
,
domain
)
position
=
INITIAL_POSITION
ift
.
plot
(
signal
(
MOCK_POSITION
),
title
=
'ground truth'
)
...
...
@@ -110,11 +109,12 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples
=
20
for
i
in
range
(
2
):
H
=
H
.
at
(
position
)
samples
=
[
H
.
metric
.
draw_sample
(
from_inverse
=
True
)
metric
=
H
(
ift
.
Linearization
.
make_var
(
position
)
).
metric
samples
=
[
metric
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
N_samples
)]
KL
=
ift
.
SampledKullbachLeiblerDivergence
(
H
,
samples
)
KL
=
EnergyAdapter
(
position
,
KL
)
KL
=
KL
.
make_invertible
(
ic_cg
)
KL
,
convergence
=
minimizer
(
KL
)
position
=
KL
.
position
...
...
nifty5/energies/kl.py
View file @
007fb8dc
...
...
@@ -19,41 +19,21 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
..
minimization.energy
import
Energy
from
..utilities
import
memo
,
my_sum
from
..
operator
import
Operator
from
..utilities
import
my_sum
class
SampledKullbachLeiblerDivergence
(
Energy
):
class
SampledKullbachLeiblerDivergence
(
Operator
):
def
__init__
(
self
,
h
,
res_samples
):
"""
# MR FIXME: does h have to be a Hamiltonian? Couldn't it be any energy?
h: Hamiltonian
N: Number of samples to be used
"""
super
(
SampledKullbachLeiblerDivergence
,
self
).
__init__
(
h
.
position
)
super
(
SampledKullbachLeiblerDivergence
,
self
).
__init__
()
self
.
_h
=
h
self
.
_res_samples
=
res_samples
self
.
_res_samples
=
tuple
(
res_samples
)
self
.
_energy_list
=
tuple
(
h
.
at
(
self
.
position
+
ss
)
for
ss
in
res_samples
)
def
at
(
self
,
position
):
return
self
.
__class__
(
self
.
_h
.
at
(
position
),
self
.
_res_samples
)
@
property
@
memo
def
value
(
self
):
return
(
my_sum
(
map
(
lambda
v
:
v
.
value
,
self
.
_energy_list
))
/
len
(
self
.
_energy_list
))
@
property
@
memo
def
gradient
(
self
):
return
(
my_sum
(
map
(
lambda
v
:
v
.
gradient
,
self
.
_energy_list
))
/
len
(
self
.
_energy_list
))
@
property
@
memo
def
metric
(
self
):
return
(
my_sum
(
map
(
lambda
v
:
v
.
metric
,
self
.
_energy_list
))
*
(
1.
/
len
(
self
.
_energy_list
)))
def
__call__
(
self
,
x
):
return
(
my_sum
(
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
))
*
(
1.
/
len
(
self
.
_res_samples
)))
nifty5/linearization.py
View file @
007fb8dc
...
...
@@ -82,7 +82,8 @@ class Linearization(object):
if
isinstance
(
other
,
(
int
,
float
,
complex
)):
# if other == 0:
# return ...
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
*
other
)
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
*
other
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
*
other
,
met
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
d2
=
makeOp
(
other
)
return
Linearization
(
self
.
_val
*
other
,
d2
*
self
.
_jac
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment