Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
91c95868
Commit
91c95868
authored
Jun 02, 2021
by
Philipp Frank
Browse files
restructure fcvi
parent
28920e82
Changes
2
Hide whitespace changes
Inline
Side-by-side
demos/meanfield_inference.py
View file @
91c95868
...
...
@@ -73,11 +73,11 @@ if __name__ == "__main__":
H
=
ift
.
StandardHamiltonian
(
likelihood
)
position_fc
=
ift
.
from_random
(
H
.
domain
)
*
0.1
position_mf
=
ift
.
from_random
(
H
.
domain
)
*
0.
position_mf
=
ift
.
from_random
(
H
.
domain
)
*
0.
1
fc
=
ift
.
FullCovarianceVI
(
position_fc
,
H
,
3
,
True
,
initial_sig
=
0.01
)
mf
=
ift
.
MeanFieldVI
(
position_mf
,
H
,
3
,
True
,
initial_sig
=
0.
00
01
)
minimizer_fc
=
ift
.
ADVIOptimizer
(
1
0
)
mf
=
ift
.
MeanFieldVI
(
position_mf
,
H
,
3
,
True
,
initial_sig
=
0.01
)
minimizer_fc
=
ift
.
ADVIOptimizer
(
20
,
eta
=
0.
1
)
minimizer_mf
=
ift
.
ADVIOptimizer
(
10
)
plt
.
pause
(
0.001
)
...
...
@@ -89,7 +89,7 @@ if __name__ == "__main__":
plt
.
figure
(
"result"
)
plt
.
cla
()
plt
.
plot
(
sky
(
fc
.
positio
n
).
val
,
sky
(
fc
.
mea
n
).
val
,
"b-"
,
label
=
"Full covariance"
,
)
...
...
@@ -98,14 +98,11 @@ if __name__ == "__main__":
)
for
i
in
range
(
5
):
plt
.
plot
(
sky
(
mf
.
draw_sample
()).
val
,
"b-"
,
alpha
=
0.3
sky
(
fc
.
draw_sample
()).
val
,
"b-"
,
alpha
=
0.3
)
plt
.
plot
(
sky
(
mf
.
draw_sample
()).
val
,
"r-"
,
alpha
=
0.3
)
#for samp in KL_mf.samples:
# plt.plot(
# sky(meanfield_model.generator(KL_mf.position + samp)).val,
# "r-",
# alpha=0.3,
# )
plt
.
plot
(
data
.
val
,
"kx"
)
plt
.
plot
(
sky
(
mock_position
).
val
,
"k-"
,
label
=
"Ground truth"
)
plt
.
legend
()
...
...
src/library/variational_models.py
View file @
91c95868
...
...
@@ -27,8 +27,8 @@ from ..operators.energy_operators import EnergyOperator
from
..operators.linear_operator
import
LinearOperator
from
..operators.multifield2vector
import
Multifield2Vector
from
..operators.sandwich_operator
import
SandwichOperator
from
..operators.simple_linear_operators
import
FieldAdapter
,
PartialExtractor
from
..sugar
import
domain_union
,
full
,
makeField
,
from_random
,
is_fieldlike
from
..operators.simple_linear_operators
import
FieldAdapter
from
..sugar
import
full
,
makeField
,
from_random
,
is_fieldlike
from
..minimization.energy_adapter
import
StochasticEnergyAdapter
...
...
@@ -80,46 +80,52 @@ class MeanFieldVI:
class
FullCovarianceVI
:
def
__init__
(
self
,
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
,
names
=
[
'mean'
,
'cov'
]
):
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
):
"""Collect the operators required for Gaussian full-covariance variational
inference.
"""
Flat
=
Multifield2Vector
(
position
.
domain
)
one_space
=
UnstructuredDomain
(
1
)
flat_domain
=
Flat
.
target
[
0
]
N_tri
=
flat_domain
.
shape
[
0
]
*
(
flat_domain
.
shape
[
0
]
+
1
)
//
2
triangular_space
=
DomainTuple
.
make
(
UnstructuredDomain
(
N_tri
))
tri
=
FieldAdapter
(
triangular_space
,
names
[
1
]
)
tri
=
FieldAdapter
(
triangular_space
,
'cov'
)
mat_space
=
DomainTuple
.
make
((
flat_domain
,
flat_domain
))
lat_mat_space
=
DomainTuple
.
make
((
one_space
,
flat_domain
))
lat
=
FieldAdapter
(
lat_mat_space
,
'latent'
)
lat
=
FieldAdapter
(
Flat
.
target
,
'latent'
)
LT
=
LowerTriangularProjector
(
triangular_space
,
mat_space
)
mean
=
FieldAdapter
(
flat_domain
,
names
[
0
]
)
mean
=
FieldAdapter
(
flat_domain
,
'mean'
)
cov
=
LT
@
tri
co
=
FieldAdapter
(
cov
.
target
,
'co'
)
matmul_setup_dom
=
domain_union
((
co
.
domain
,
lat
.
domain
))
co_part
=
PartialExtractor
(
matmul_setup_dom
,
co
.
domain
)
lat_part
=
PartialExtractor
(
matmul_setup_dom
,
lat
.
domain
)
matmul_setup
=
lat_part
.
adjoint
@
lat
.
adjoint
@
lat
+
co_part
.
adjoint
@
co
.
adjoint
@
cov
MatMult
=
MultiLinearEinsum
(
matmul_setup
.
target
,
'ij,ki->jk'
,
key_order
=
(
'co'
,
'latent'
))
Resp
=
Respacer
(
MatMult
.
target
,
mean
.
target
)
generator
=
Flat
.
adjoint
@
(
mean
+
Resp
@
MatMult
@
matmul_setup
)
matmul_setup
=
lat
.
adjoint
@
lat
+
cov
.
ducktape_left
(
'co'
)
MatMult
=
MultiLinearEinsum
(
matmul_setup
.
target
,
'ij,j->i'
,
key_order
=
(
'co'
,
'latent'
))
self
.
_generator
=
Flat
.
adjoint
@
(
mean
+
MatMult
@
matmul_setup
)
Diag
=
DiagonalSelector
(
cov
.
target
,
Flat
.
target
)
diag_cov
=
Diag
(
cov
).
absolute
()
entropy
=
GaussianEntropy
(
diag_cov
.
target
)
@
diag_cov
diag_tri
=
np
.
diag
(
np
.
full
(
flat_domain
.
shape
[
0
],
initial_sig
))[
np
.
tril_indices
(
flat_domain
.
shape
[
0
])]
pos
=
MultiField
.
from_dict
({
names
[
0
]:
Flat
(
position
),
names
[
1
]:
makeField
(
generator
.
domain
[
names
[
1
]],
diag_tri
)})
op
=
hamiltonian
(
generator
)
+
entropy
self
.
_names
=
names
self
.
_KL
=
StochasticEnergyAdapter
.
make
(
pos
,
op
,
[
'latent'
,],
n_samples
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_Flat
=
Flat
self
.
_entropy
=
GaussianEntropy
(
diag_cov
.
target
)
@
diag_cov
diag_tri
=
np
.
diag
(
np
.
full
(
flat_domain
.
shape
[
0
],
initial_sig
))
diag_tri
=
diag_tri
[
np
.
tril_indices
(
flat_domain
.
shape
[
0
])]
pos
=
MultiField
.
from_dict
(
{
'mean'
:
Flat
(
position
),
'cov'
:
makeField
(
triangular_space
,
diag_tri
)})
op
=
hamiltonian
(
self
.
_generator
)
+
self
.
_entropy
self
.
_KL
=
StochasticEnergyAdapter
.
make
(
pos
,
op
,
[
'latent'
,],
n_samples
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_mean
=
Flat
.
adjoint
@
mean
self
.
_samdom
=
lat
.
domain
@
property
def
mean
(
self
):
return
self
.
_mean
.
force
(
self
.
_KL
.
position
)
@
property
def
position
(
self
):
return
self
.
_Flat
.
adjoint
(
self
.
_KL
.
position
[
self
.
_names
[
0
]])
def
entropy
(
self
):
return
self
.
_entropy
.
force
(
self
.
_KL
.
position
)
def
draw_sample
(
self
):
_
,
op
=
self
.
_generator
.
simplify_for_constant_input
(
from_random
(
self
.
_samdom
))
return
op
(
self
.
_KL
.
position
)
def
minimize
(
self
,
minimizer
):
self
.
_KL
,
_
=
minimizer
(
self
.
_KL
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment