Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
59a009cb
Commit
59a009cb
authored
Jun 01, 2021
by
Philipp Arras
Browse files
Tweak docs
parent
0d3e909a
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/library/variational_models.py
View file @
59a009cb
...
...
@@ -33,14 +33,15 @@ from ..sugar import domain_union, from_random, full, makeField
class
MeanfieldModel
():
'''
Collects the operators required for Gaussian mean-field variational
inference.
"""Collect the operators required for Gaussian mean-field variational
inference.
Parameters
----------
domain: MultiDomain
The domain of the model parameters.
'''
The domain of the model parameters.
"""
def
__init__
(
self
,
domain
):
self
.
domain
=
MultiDomain
.
make
(
domain
)
self
.
Flat
=
Multifield2Vector
(
self
.
domain
)
...
...
@@ -52,17 +53,18 @@ class MeanfieldModel():
self
.
entropy
=
GaussianEntropy
(
self
.
std
.
target
)
@
self
.
std
def
get_initial_pos
(
self
,
initial_mean
=
None
,
initial_sig
=
1
):
'''
Provides an initial position for a given mean parameter vector and an
initial standard deviation.
"""Provide an initial position for a given mean parameter vector and an
initial standard deviation.
Parameters
----------
initial_mean: MultiField
The initial mean of the variational approximation. If not None, a Gaussian sample with mean zero and standard deviation of 0.1 is used.
Default: None
The initial mean of the variational approximation. If not None, a
Gaussian sample with mean zero and standard deviation of 0.1 is
used. Default: None
initial_sig: positive float
The initial standard deviation shared by all parameters. Default: 1
'''
The initial standard deviation shared by all parameters. Default: 1
"""
initial_pos
=
from_random
(
self
.
generator
.
domain
).
to_dict
()
initial_pos
[
'latent'
]
=
full
(
self
.
generator
.
domain
[
'latent'
],
0.
)
...
...
@@ -76,14 +78,15 @@ class MeanfieldModel():
class
FullCovarianceModel
():
'''
Collects the operators required for Gaussian full-covariance variational
inference.
"""Collect the operators required for Gaussian full-covariance variational
inference.
Parameters
----------
domain: MultiDomain
The domain of the model parameters.
'''
The domain of the model parameters.
"""
def
__init__
(
self
,
domain
):
self
.
domain
=
MultiDomain
.
make
(
domain
)
self
.
Flat
=
Multifield2Vector
(
self
.
domain
)
...
...
@@ -108,23 +111,24 @@ class FullCovarianceModel():
Resp
=
Respacer
(
MatMult
.
target
,
mean
.
target
)
self
.
generator
=
self
.
Flat
.
adjoint
@
(
mean
+
Resp
@
MatMult
@
matmul_setup
)
Diag
=
DiagonalSelector
(
cov
.
target
,
self
.
Flat
.
target
)
diag_cov
=
Diag
(
cov
).
absolute
()
self
.
entropy
=
GaussianEntropy
(
diag_cov
.
target
)
@
diag_cov
def
get_initial_pos
(
self
,
initial_mean
=
None
,
initial_sig
=
1
):
'''
Provides an initial position for a given mean parameter vector and a
diagonal covariance with an initial standard deviation.
"""Provide an initial position for a given mean parameter vector and a
diagonal covariance with an initial standard deviation.
Parameters
----------
initial_mean: MultiField
The initial mean of the variational approximation. If not None, a Gaussian sample with mean zero and standard deviation of 0.1 is used.
Default: None
The initial mean of the variational approximation. If not None, a
Gaussian sample with mean zero and standard deviation of 0.1 is
used. Default: None
initial_sig: positive float
The initial standard deviation shared by all parameters. Default: 1
'''
The initial standard deviation shared by all parameters. Default: 1
"""
initial_pos
=
from_random
(
self
.
generator
.
domain
).
to_dict
()
initial_pos
[
'latent'
]
=
full
(
self
.
generator
.
domain
[
'latent'
],
0.
)
diag_tri
=
np
.
diag
(
np
.
full
(
self
.
flat_domain
.
shape
[
0
],
initial_sig
))[
np
.
tril_indices
(
self
.
flat_domain
.
shape
[
0
])]
...
...
@@ -136,14 +140,15 @@ class FullCovarianceModel():
class
GaussianEntropy
(
EnergyOperator
):
'''
Calculates the entropy of a Gaussian distribution given the diagonal of a
triangular decomposition of the covariance.
"""Calculate the entropy of a Gaussian distribution given the diagonal of a
triangular decomposition of the covariance.
Parameters
----------
domain: Domain
The domain of the diagonal.
'''
The domain of the diagonal.
"""
def
__init__
(
self
,
domain
):
self
.
_domain
=
domain
...
...
@@ -159,16 +164,17 @@ class GaussianEntropy(EnergyOperator):
class
LowerTriangularProjector
(
LinearOperator
):
'''
Projects the DOFs of a triangular matrix into the matrix form.
"""Project the DOFs of a triangular matrix into the matrix form.
Parameters
----------
domain: Domain
A one-dimensional domain containing N(N+1)/2 DOFs of a triangular matrix.
A one-dimensional domain containing N(N+1)/2 DOFs of a triangular
matrix.
target: Domain
A two-dimensional domain with NxN entries.
'''
A two-dimensional domain with NxN entries.
"""
def
__init__
(
self
,
domain
,
target
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_target
=
DomainTuple
.
make
(
target
)
...
...
@@ -187,16 +193,17 @@ class LowerTriangularProjector(LinearOperator):
class
DiagonalSelector
(
LinearOperator
):
'''
Extracts the diagonal of a two-dimensional field.
"""Extract the diagonal of a two-dimensional field.
Parameters
----------
domain: Domain
The two-dimensional domain of the input field
The two-dimensional domain of the input field
target: Domain
A one-dimensional domain in which the diagonal of the input field lives.
'''
The one-dimensional domain on which the diagonal of the input field is
defined.
"""
def
__init__
(
self
,
domain
,
target
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_target
=
DomainTuple
.
make
(
target
)
...
...
@@ -211,16 +218,16 @@ class DiagonalSelector(LinearOperator):
class
Respacer
(
LinearOperator
):
'''
Re-maps a field from one domain to another one with the same amounts of
DOFs. Wrapps the numpy.reshape method.
"""Re-map a field from one domain to another one with the same amounts of
DOFs. Wrapps the numpy.reshape method.
Parameters
----------
domain: Domain
The domain of the input field.
The domain of the input field.
target: Domain
The domain of the output field.
'''
The domain of the output field.
"""
def
__init__
(
self
,
domain
,
target
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
...
...
src/minimization/stochastic_minimizer.py
View file @
59a009cb
...
...
@@ -19,26 +19,25 @@ from .minimizer import Minimizer
class
ADVIOptimizer
(
Minimizer
):
'''
Provides an implementation of an adaptive step-size sequence optimizer,
following https://arxiv.org/abs/1603.00788.
"""Provide an implementation of an adaptive step-size sequence optimizer,
following https://arxiv.org/abs/1603.00788.
Parameters
----------
steps: int
The number of concecutive steps during one call of the optimizer.
eta: positive float
The scale of the step-size sequence. It might have to be adapted to the application to increase performance. Default: 1.
The scale of the step-size sequence. It might have to be adapted to the
application to increase performance. Default: 1.
alpha: float between 0 and 1
The fraction of how much the current gradient impacts the momentum.
The fraction of how much the current gradient impacts the momentum.
tau: positive float
This quantity prevents division by zero.
epsilon: positive float
A small value guarantees Robbins and Monro conditions.
'''
"""
def
__init__
(
self
,
steps
,
eta
=
1
,
alpha
=
0.1
,
tau
=
1
,
epsilon
=
1e-16
):
self
.
alpha
=
alpha
self
.
eta
=
eta
self
.
tau
=
tau
...
...
@@ -59,15 +58,6 @@ class ADVIOptimizer(Minimizer):
return
new_position
def
__call__
(
self
,
E
):
'''
Performs the optimization.
Parameters
----------
E: EnergyOperator
The target function.
'''
from
..minimization.parametric_gaussian_kl
import
ParametricGaussianKL
if
self
.
s
is
None
:
self
.
s
=
E
.
gradient
**
2
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment