Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1f598ab5
Commit
1f598ab5
authored
Oct 02, 2019
by
Reimar H Leike
Committed by
Martin Reinecke
Oct 02, 2019
Browse files
added metric operator
parent
ffc6059b
Pipeline
#61335
passed with stages
in 9 minutes and 21 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Sidebyside
nifty5/minimization/metric_gaussian_kl_mpi.py
View file @
1f598ab5
...
...
@@ 18,9 +18,12 @@
from
..
import
utilities
from
..linearization
import
Linearization
from
..operators.energy_operators
import
StandardHamiltonian
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
.energy
import
Energy
from
mpi4py
import
MPI
import
numpy
as
np
from
..probing
import
approximation2endo
from
..sugar
import
makeOp
from
..field
import
Field
from
..multi_field
import
MultiField
...
...
@@ 56,10 +59,83 @@ def allreduce_sum_field(fld):
return
MultiField
(
fld
.
domain
,
res
)
class
KLMetric
(
EndomorphicOperator
):
def
__init__
(
self
,
KL
):
self
.
_KL
=
KL
self
.
_capability
=
self
.
TIMES

self
.
ADJOINT_TIMES
self
.
_domain
=
KL
.
position
.
domain
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
self
.
_KL
.
apply_metric
(
x
)
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
self
.
_KL
.
metric_sample
(
from_inverse
,
dtype
)
class
MetricGaussianKL_MPI
(
Energy
):
"""Provides the sampled KullbackLeibler divergence between a distribution
and a Metric Gaussian.
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
KullbackLeibler divergence is minimized. This estimate is obtained by
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by intantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class are distributed among MPI tasks.
Parameters

mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negative of the drawn samples are also used,
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
seed_offset : int
A parameter with which one can controll from which seed the samples
are drawn. Per default, the seed is different for MPI tasks, but the
same every time this class is initialized.
Note

The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
See also

`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
point_estimates
=
[],
mirror_samples
=
False
,
_samples
=
None
,
seed_offset
=
0
):
napprox
=
0
,
_samples
=
None
,
seed_offset
=
0
):
super
(
MetricGaussianKL_MPI
,
self
).
__init__
(
mean
)
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
...
...
@@ 82,6 +158,8 @@ class MetricGaussianKL_MPI(Energy):
lo
,
hi
=
_shareRange
(
n_samples
,
ntask
,
rank
)
met
=
hamiltonian
(
Linearization
.
make_partial_var
(
mean
,
point_estimates
,
True
)).
metric
if
napprox
>
1
:
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
_samples
=
[]
for
i
in
range
(
lo
,
hi
):
if
mirror_samples
:
...
...
@@ 142,8 +220,8 @@ class MetricGaussianKL_MPI(Energy):
else
:
mymap
=
map
(
lambda
v
:
self
.
_hamiltonian
(
lin
+
v
).
metric
,
self
.
_samples
)
self
.
_metric
=
utilities
.
my_sum
(
mymap
)
self
.
_metric
=
self
.
_metric
.
scale
(
1.
/
self
.
_n_samples
)
self
.
unscaled
_metric
=
utilities
.
my_sum
(
mymap
)
self
.
_metric
=
self
.
unscaled
_metric
.
scale
(
1.
/
self
.
_n_samples
)
def
apply_metric
(
self
,
x
):
self
.
_get_metric
()
...
...
@@ 151,12 +229,22 @@ class MetricGaussianKL_MPI(Energy):
@
property
def
metric
(
self
):
if
ntask
>
1
:
raise
ValueError
(
"not supported when MPI is active"
)
return
self
.
_metric
return
KLMetric
(
self
)
@
property
def
samples
(
self
):
res
=
_comm
.
allgather
(
self
.
_samples
)
res
=
[
item
for
sublist
in
res
for
item
in
sublist
]
return
res
def
unscaled_metric_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
from_inverse
:
raise
NotImplementedError
()
lin
=
self
.
_lin
.
with_want_metric
()
samp
=
ift
.
full
(
self
.
_hamiltonian
.
domain
,
0.
)
for
s
in
self
.
_samples
:
samp
=
samp
+
self
.
_hamiltonian
(
lin
+
v
).
metric
.
draw_sample
(
dtype
)
return
allreduce_sum_field
(
samp
)
def
metric_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
return
self
.
unscaled_metric_sample
(
from_inverse
,
dtype
)
/
self
.
_n_samples
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment