Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
6be924b0
Commit
6be924b0
authored
Jun 03, 2021
by
Philipp Arras
Browse files
Docs
parent
9a6f726b
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/library/variational_models.py
View file @
6be924b0
...
...
@@ -34,12 +34,29 @@ from ..utilities import myassert
class
MeanFieldVI
:
def
__init__
(
self
,
initial_position
,
hamiltonian
,
n_samples
,
mirror_samples
,
"""Collect the operators required for Gaussian meanfield variational
inference.
Parameters
----------
position :
FIXME
hamiltonian :
FIXME
n_samples :
FIXME
mirror_samples :
FIXME
initial_sig :
FIXME
comm :
FIXME
nanisinf :
FIXME
"""
def
__init__
(
self
,
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
):
"""Collect the operators required for Gaussian mean-field variational
inference.
"""
Flat
=
Multifield2Vector
(
initial_position
.
domain
)
Flat
=
Multifield2Vector
(
position
.
domain
)
self
.
_std
=
FieldAdapter
(
Flat
.
target
,
'std'
).
absolute
()
latent
=
FieldAdapter
(
Flat
.
target
,
'latent'
)
self
.
_mean
=
FieldAdapter
(
Flat
.
target
,
'mean'
)
...
...
@@ -47,7 +64,7 @@ class MeanFieldVI:
self
.
_entropy
=
GaussianEntropy
(
self
.
_std
.
target
)
@
self
.
_std
self
.
_mean
=
Flat
.
adjoint
@
self
.
_mean
self
.
_std
=
Flat
.
adjoint
@
self
.
_std
pos
=
{
'mean'
:
Flat
(
initial_
position
)}
pos
=
{
'mean'
:
Flat
(
position
)}
if
is_fieldlike
(
initial_sig
):
pos
[
'std'
]
=
Flat
(
initial_sig
)
else
:
...
...
@@ -78,12 +95,30 @@ class MeanFieldVI:
def
minimize
(
self
,
minimizer
):
self
.
_KL
,
_
=
minimizer
(
self
.
_KL
)
class
FullCovarianceVI
:
"""Collect the operators required for Gaussian full-covariance variational
inference.
Parameters
----------
position :
FIXME
hamiltonian :
FIXME
n_samples :
FIXME
mirror_samples :
FIXME
initial_sig :
FIXME
comm :
FIXME
nanisinf :
FIXME
"""
def
__init__
(
self
,
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
):
"""Collect the operators required for Gaussian full-covariance variational
inference.
"""
Flat
=
Multifield2Vector
(
position
.
domain
)
flat_domain
=
Flat
.
target
[
0
]
mat_space
=
DomainTuple
.
make
((
flat_domain
,
flat_domain
))
...
...
@@ -128,8 +163,8 @@ class FullCovarianceVI:
class
GaussianEntropy
(
EnergyOperator
):
"""
Calculate the e
ntropy of a Gaussian distribution given the diagonal of a
triangular
decomposition of the covariance.
"""
E
ntropy of a Gaussian distribution given the diagonal of a
triangular
decomposition of the covariance.
Parameters
----------
...
...
@@ -152,7 +187,7 @@ class GaussianEntropy(EnergyOperator):
class
LowerTriangularInserter
(
LinearOperator
):
"""Insert
s
the
DOF
s of a lower triangular matrix into a matrix.
"""Insert the
entrie
s of a lower triangular matrix into a matrix.
Parameters
----------
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment