Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
6251760d
Commit
6251760d
authored
May 29, 2020
by
Philipp Arras
Browse files
Simplify KL
parent
049e7fa9
Pipeline
#75762
passed with stages
in 23 minutes and 29 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/minimization/metric_gaussian_kl.py
View file @
6251760d
...
@@ -19,6 +19,7 @@ import numpy as np
...
@@ -19,6 +19,7 @@ import numpy as np
from
..
import
random
from
..
import
random
from
..linearization
import
Linearization
from
..linearization
import
Linearization
from
..logger
import
logger
from
..multi_field
import
MultiField
from
..multi_field
import
MultiField
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
..operators.energy_operators
import
StandardHamiltonian
from
..operators.energy_operators
import
StandardHamiltonian
...
@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy):
...
@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy):
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
point_estimates
=
[],
mirror_samples
=
False
,
point_estimates
=
[],
mirror_samples
=
False
,
napprox
=
0
,
comm
=
None
,
_local_samples
=
None
,
napprox
=
0
,
comm
=
None
,
_local_samples
=
None
,
nanisinf
=
False
,
_ham4eval
=
None
):
nanisinf
=
False
):
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
...
@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy):
...
@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy):
raise
ValueError
raise
ValueError
if
not
isinstance
(
n_samples
,
int
):
if
not
isinstance
(
n_samples
,
int
):
raise
TypeError
raise
TypeError
self
.
_constants
=
tuple
(
constants
)
self
.
_point_estimates
=
tuple
(
point_estimates
)
self
.
_mitigate_nans
=
nanisinf
self
.
_mitigate_nans
=
nanisinf
if
not
isinstance
(
mirror_samples
,
bool
):
if
not
isinstance
(
mirror_samples
,
bool
):
raise
TypeError
raise
TypeError
...
@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy):
...
@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy):
'Point estimates for whole domain. Use EnergyAdapter instead.'
)
'Point estimates for whole domain. Use EnergyAdapter instead.'
)
self
.
_hamiltonian
=
hamiltonian
self
.
_hamiltonian
=
hamiltonian
self
.
_ham4eval
=
_ham4eval
if
len
(
constants
)
>
0
:
if
self
.
_ham4eval
is
None
:
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
if
kk
in
constants
}
if
len
(
constants
)
>
0
:
dom
=
makeDomain
(
dom
)
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
if
kk
in
constants
}
cstpos
=
mean
.
extract
(
dom
)
dom
=
makeDomain
(
dom
)
_
,
self
.
_hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
cstpos
=
mean
.
extract
(
dom
)
_
,
self
.
_ham4eval
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
else
:
self
.
_ham4eval
=
hamiltonian
self
.
_n_samples
=
int
(
n_samples
)
self
.
_n_samples
=
int
(
n_samples
)
if
comm
is
not
None
:
if
comm
is
not
None
:
...
@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy):
...
@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy):
self
.
_n_eff_samples
*=
2
self
.
_n_eff_samples
*=
2
if
_local_samples
is
None
:
if
_local_samples
is
None
:
sample_hamiltonian
=
hamiltonian
if
len
(
point_estimates
)
>
0
:
if
len
(
point_estimates
)
>
0
:
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
dom
=
{
kk
:
vv
for
kk
,
vv
in
mean
.
domain
.
items
()
if
kk
in
point_estimates
}
if
kk
in
point_estimates
}
dom
=
makeDomain
(
dom
)
dom
=
makeDomain
(
dom
)
cstpos
=
mean
.
extract
(
dom
)
cstpos
=
mean
.
extract
(
dom
)
_
,
sample_
hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
_
,
hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
met
=
sample_
hamiltonian
(
Linearization
.
make_var
(
mean
,
True
)).
metric
met
=
hamiltonian
(
Linearization
.
make_var
(
mean
,
True
)).
metric
if
napprox
>=
1
:
if
napprox
>=
1
:
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
_local_samples
=
[]
_local_samples
=
[]
...
@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy):
...
@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy):
self
.
_lin
=
Linearization
.
make_var
(
mean
)
self
.
_lin
=
Linearization
.
make_var
(
mean
)
v
,
g
=
[],
[]
v
,
g
=
[],
[]
for
s
in
self
.
_local_samples
:
for
s
in
self
.
_local_samples
:
tmp
=
self
.
_ham
4eval
(
self
.
_lin
+
s
)
tmp
=
self
.
_ham
iltonian
(
self
.
_lin
+
s
)
tv
=
tmp
.
val
.
val
tv
=
tmp
.
val
.
val
tg
=
tmp
.
gradient
tg
=
tmp
.
gradient
if
self
.
_mirror_samples
:
if
self
.
_mirror_samples
:
tmp
=
self
.
_ham
4eval
(
self
.
_lin
-
s
)
tmp
=
self
.
_ham
iltonian
(
self
.
_lin
-
s
)
tv
=
tv
+
tmp
.
val
.
val
tv
=
tv
+
tmp
.
val
.
val
tg
=
tg
+
tmp
.
gradient
tg
=
tg
+
tmp
.
gradient
v
.
append
(
tv
)
v
.
append
(
tv
)
g
.
append
(
tg
)
g
.
append
(
tg
)
self
.
_val
=
self
.
_sumup
(
v
)[()]
/
self
.
_n_eff_samples
self
.
_val
=
self
.
_sumup
(
v
)[()]
/
self
.
_n_eff_samples
if
np
.
isnan
(
self
.
_val
)
and
self
.
_mitigate_nans
:
if
self
.
_mitigate_nans
and
np
.
isnan
(
self
.
_val
)
:
self
.
_val
=
np
.
inf
self
.
_val
=
np
.
inf
self
.
_grad
=
self
.
_sumup
(
g
)
/
self
.
_n_eff_samples
self
.
_grad
=
self
.
_sumup
(
g
)
/
self
.
_n_eff_samples
self
.
_metric
=
None
self
.
_metric
=
None
def
at
(
self
,
position
):
def
at
(
self
,
position
):
return
MetricGaussianKL
(
return
MetricGaussianKL
(
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
self
.
_constants
,
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
self
.
_point_estimates
,
self
.
_mirror_samples
,
comm
=
self
.
_comm
,
mirror_samples
=
self
.
_mirror_samples
,
comm
=
self
.
_comm
,
_local_samples
=
self
.
_local_samples
,
nanisinf
=
self
.
_mitigate_nans
,
_local_samples
=
self
.
_local_samples
,
nanisinf
=
self
.
_mitigate_nans
)
_ham4eval
=
self
.
_ham4eval
)
@
property
@
property
def
value
(
self
):
def
value
(
self
):
...
@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy):
...
@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy):
lin
=
self
.
_lin
.
with_want_metric
()
lin
=
self
.
_lin
.
with_want_metric
()
res
=
[]
res
=
[]
for
s
in
self
.
_local_samples
:
for
s
in
self
.
_local_samples
:
tmp
=
self
.
_ham
4eval
(
lin
+
s
).
metric
(
x
)
tmp
=
self
.
_ham
iltonian
(
lin
+
s
).
metric
(
x
)
if
self
.
_mirror_samples
:
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_ham
4eval
(
lin
-
s
).
metric
(
x
)
tmp
=
tmp
+
self
.
_ham
iltonian
(
lin
-
s
).
metric
(
x
)
res
.
append
(
tmp
)
res
.
append
(
tmp
)
return
self
.
_sumup
(
res
)
/
self
.
_n_eff_samples
return
self
.
_sumup
(
res
)
/
self
.
_n_eff_samples
...
@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy):
...
@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy):
def
_metric_sample
(
self
,
from_inverse
=
False
):
def
_metric_sample
(
self
,
from_inverse
=
False
):
if
from_inverse
:
if
from_inverse
:
raise
NotImplementedError
()
raise
NotImplementedError
()
s
=
(
'This draws from the Hamiltonian used for evaluation and does '
' not take point_estimates into accout. Make sure that this '
'is your intended use.'
)
logger
.
warning
(
s
)
lin
=
self
.
_lin
.
with_want_metric
()
lin
=
self
.
_lin
.
with_want_metric
()
samp
=
[]
samp
=
[]
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment