Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
6251760d
Commit
6251760d
authored
May 29, 2020
by
Philipp Arras
Browse files
Simplify KL
parent
049e7fa9
Pipeline
#75762
passed with stages
in 23 minutes and 29 seconds
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
src/minimization/metric_gaussian_kl.py
View file @
6251760d
...
...
@@ -19,6 +19,7 @@ import numpy as np
from
..
import
random
from
..linearization
import
Linearization
from
..logger
import
logger
from
..multi_field
import
MultiField
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
..operators.energy_operators
import
StandardHamiltonian
...
...
@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy):
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
point_estimates
=
[],
mirror_samples
=
False
,
napprox
=
0
,
comm
=
None
,
_local_samples
=
None
,
nanisinf
=
False
,
_ham4eval
=
None
):
nanisinf
=
False
):
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
...
...
@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy):
raise
ValueError
if
not
isinstance
(
n_samples
,
int
):
raise
TypeError
self
.
_constants
=
tuple
(
constants
)
self
.
_point_estimates
=
tuple
(
point_estimates
)
self
.
_mitigate_nans
=
nanisinf
if
not
isinstance
(
mirror_samples
,
bool
):
raise
TypeError
...
...
@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy):
'Point estimates for whole domain. Use EnergyAdapter instead.'
)
self
.
_hamiltonian
=
hamiltonian
self
.
_ham4eval
=
_ham4eval
if
self
.
_ham4eval
is
None
:
if
len
(
constants
)
>
0
:
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
if
kk
in
constants
}
dom
=
makeDomain
(
dom
)
cstpos
=
mean
.
extract
(
dom
)
_
,
self
.
_ham4eval
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
else
:
self
.
_ham4eval
=
hamiltonian
_
,
self
.
_hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
self
.
_n_samples
=
int
(
n_samples
)
if
comm
is
not
None
:
...
...
@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy):
self
.
_n_eff_samples
*=
2
if
_local_samples
is
None
:
sample_hamiltonian
=
hamiltonian
if
len
(
point_estimates
)
>
0
:
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
if
kk
in
point_estimates
}
dom
=
makeDomain
(
dom
)
cstpos
=
mean
.
extract
(
dom
)
_
,
sample_
hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
met
=
sample_
hamiltonian
(
Linearization
.
make_var
(
mean
,
True
)).
metric
_
,
hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
met
=
hamiltonian
(
Linearization
.
make_var
(
mean
,
True
)).
metric
if
napprox
>=
1
:
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
_local_samples
=
[]
...
...
@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy):
self
.
_lin
=
Linearization
.
make_var
(
mean
)
v
,
g
=
[],
[]
for
s
in
self
.
_local_samples
:
tmp
=
self
.
_ham
4eval
(
self
.
_lin
+
s
)
tmp
=
self
.
_ham
iltonian
(
self
.
_lin
+
s
)
tv
=
tmp
.
val
.
val
tg
=
tmp
.
gradient
if
self
.
_mirror_samples
:
tmp
=
self
.
_ham
4eval
(
self
.
_lin
-
s
)
tmp
=
self
.
_ham
iltonian
(
self
.
_lin
-
s
)
tv
=
tv
+
tmp
.
val
.
val
tg
=
tg
+
tmp
.
gradient
v
.
append
(
tv
)
g
.
append
(
tg
)
self
.
_val
=
self
.
_sumup
(
v
)[()]
/
self
.
_n_eff_samples
if
np
.
isnan
(
self
.
_val
)
and
self
.
_mitigate_nans
:
if
self
.
_mitigate_nans
and
np
.
isnan
(
self
.
_val
)
:
self
.
_val
=
np
.
inf
self
.
_grad
=
self
.
_sumup
(
g
)
/
self
.
_n_eff_samples
self
.
_metric
=
None
def
at
(
self
,
position
):
return
MetricGaussianKL
(
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
self
.
_constants
,
self
.
_point_estimates
,
self
.
_mirror_samples
,
comm
=
self
.
_comm
,
_local_samples
=
self
.
_local_samples
,
nanisinf
=
self
.
_mitigate_nans
,
_ham4eval
=
self
.
_ham4eval
)
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
mirror_samples
=
self
.
_mirror_samples
,
comm
=
self
.
_comm
,
_local_samples
=
self
.
_local_samples
,
nanisinf
=
self
.
_mitigate_nans
)
@
property
def
value
(
self
):
...
...
@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy):
lin
=
self
.
_lin
.
with_want_metric
()
res
=
[]
for
s
in
self
.
_local_samples
:
tmp
=
self
.
_ham
4eval
(
lin
+
s
).
metric
(
x
)
tmp
=
self
.
_ham
iltonian
(
lin
+
s
).
metric
(
x
)
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_ham
4eval
(
lin
-
s
).
metric
(
x
)
tmp
=
tmp
+
self
.
_ham
iltonian
(
lin
-
s
).
metric
(
x
)
res
.
append
(
tmp
)
return
self
.
_sumup
(
res
)
/
self
.
_n_eff_samples
...
...
@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy):
def
_metric_sample
(
self
,
from_inverse
=
False
):
if
from_inverse
:
raise
NotImplementedError
()
s
=
(
'This draws from the Hamiltonian used for evaluation and does '
' not take point_estimates into accout. Make sure that this '
'is your intended use.'
)
logger
.
warning
(
s
)
lin
=
self
.
_lin
.
with_want_metric
()
samp
=
[]
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment