Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
28920e82
Commit
28920e82
authored
Jun 02, 2021
by
Philipp Frank
Browse files
restructure mfvi
parent
f30d1547
Changes
4
Hide whitespace changes
Inline
Side-by-side
demos/meanfield_inference.py
View file @
28920e82
...
...
@@ -75,8 +75,8 @@ if __name__ == "__main__":
position_fc
=
ift
.
from_random
(
H
.
domain
)
*
0.1
position_mf
=
ift
.
from_random
(
H
.
domain
)
*
0.
fc
=
ift
.
FullCovariance
(
position_fc
,
H
,
3
,
True
,
initial_sig
=
0.01
)
mf
=
ift
.
MeanField
(
position_mf
,
H
,
3
,
True
,
initial_sig
=
0.0001
)
fc
=
ift
.
FullCovariance
VI
(
position_fc
,
H
,
3
,
True
,
initial_sig
=
0.01
)
mf
=
ift
.
MeanField
VI
(
position_mf
,
H
,
3
,
True
,
initial_sig
=
0.0001
)
minimizer_fc
=
ift
.
ADVIOptimizer
(
10
)
minimizer_mf
=
ift
.
ADVIOptimizer
(
10
)
...
...
@@ -94,12 +94,12 @@ if __name__ == "__main__":
label
=
"Full covariance"
,
)
plt
.
plot
(
sky
(
mf
.
positio
n
).
val
,
"r-"
,
label
=
"Mean field"
sky
(
mf
.
mea
n
).
val
,
"r-"
,
label
=
"Mean field"
)
#
for
samp in KL_fc.samples
:
#
plt.plot(
#
sky(
fullcov_model.generator(KL_fc.position +
samp)).val, "b-", alpha=0.3
#
)
for
i
in
range
(
5
)
:
plt
.
plot
(
sky
(
mf
.
draw_
samp
le
(
)).
val
,
"b-"
,
alpha
=
0.3
)
#for samp in KL_mf.samples:
# plt.plot(
# sky(meanfield_model.generator(KL_mf.position + samp)).val,
...
...
src/__init__.py
View file @
28920e82
...
...
@@ -92,7 +92,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
from
.library.nft
import
Gridder
,
FinuFFT
from
.library.correlated_fields
import
CorrelatedFieldMaker
from
.library.correlated_fields_simple
import
SimpleCorrelatedField
from
.library.variational_models
import
MeanField
,
FullCovariance
from
.library.variational_models
import
MeanField
VI
,
FullCovariance
VI
from
.
import
extra
...
...
src/library/variational_models.py
View file @
28920e82
...
...
@@ -28,42 +28,57 @@ from ..operators.linear_operator import LinearOperator
from
..operators.multifield2vector
import
Multifield2Vector
from
..operators.sandwich_operator
import
SandwichOperator
from
..operators.simple_linear_operators
import
FieldAdapter
,
PartialExtractor
from
..sugar
import
domain_union
,
full
,
makeField
,
is_fieldlike
from
..minimization.
stochastic_minimizer
import
PartialSampledEnergy
from
..sugar
import
domain_union
,
full
,
makeField
,
from_random
,
is_fieldlike
from
..minimization.
energy_adapter
import
StochasticEnergyAdapter
class
MeanField
:
def
__init__
(
self
,
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
,
names
=
[
'mean'
,
'var'
]
):
class
MeanField
VI
:
def
__init__
(
self
,
initial_
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
):
"""Collect the operators required for Gaussian mean-field variational
inference.
"""
Flat
=
Multifield2Vector
(
position
.
domain
)
std
=
FieldAdapter
(
Flat
.
target
,
names
[
1
]
).
absolute
()
Flat
=
Multifield2Vector
(
initial_
position
.
domain
)
self
.
_
std
=
FieldAdapter
(
Flat
.
target
,
'std'
).
absolute
()
latent
=
FieldAdapter
(
Flat
.
target
,
'latent'
)
mean
=
FieldAdapter
(
Flat
.
target
,
names
[
0
])
generator
=
Flat
.
adjoint
(
mean
+
std
*
latent
)
entropy
=
GaussianEntropy
(
std
.
target
)
@
std
pos
=
{
names
[
0
]:
Flat
(
position
)}
self
.
_mean
=
FieldAdapter
(
Flat
.
target
,
'mean'
)
self
.
_generator
=
Flat
.
adjoint
(
self
.
_mean
+
self
.
_std
*
latent
)
self
.
_entropy
=
GaussianEntropy
(
self
.
_std
.
target
)
@
self
.
_std
self
.
_mean
=
Flat
.
adjoint
@
self
.
_mean
self
.
_std
=
Flat
.
adjoint
@
self
.
_std
pos
=
{
'mean'
:
Flat
(
initial_position
)}
if
is_fieldlike
(
initial_sig
):
pos
[
names
[
1
]
]
=
Flat
(
initial_sig
)
pos
[
'std'
]
=
Flat
(
initial_sig
)
else
:
pos
[
names
[
1
]
]
=
full
(
Flat
.
target
,
initial_sig
)
pos
[
'std'
]
=
full
(
Flat
.
target
,
initial_sig
)
pos
=
MultiField
.
from_dict
(
pos
)
op
=
hamiltonian
(
generator
)
+
entropy
self
.
_
names
=
nam
es
self
.
_KL
=
PartialSampledEnergy
.
make
(
pos
,
op
,
[
'latent'
,],
n_samples
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_
Flat
=
F
lat
op
=
hamiltonian
(
self
.
_
generator
)
+
self
.
_
entropy
self
.
_
KL
=
StochasticEnergyAdapter
.
make
(
pos
,
op
,
[
'latent'
,],
n_sampl
es
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_
samdom
=
lat
ent
.
domain
@
property
def
position
(
self
):
return
self
.
_Flat
.
adjoint
(
self
.
_KL
.
position
[
self
.
_names
[
0
]])
def
mean
(
self
):
return
self
.
_mean
.
force
(
self
.
_KL
.
position
)
@
property
def
std
(
self
):
return
self
.
_std
.
force
(
self
.
_KL
.
position
)
@
property
def
entropy
(
self
):
return
self
.
_entropy
.
force
(
self
.
_KL
.
position
)
def
draw_sample
(
self
):
_
,
op
=
self
.
_generator
.
simplify_for_constant_input
(
from_random
(
self
.
_samdom
))
return
op
(
self
.
_KL
.
position
)
def
minimize
(
self
,
minimizer
):
self
.
_KL
,
_
=
minimizer
(
self
.
_KL
)
class
FullCovariance
:
class
FullCovariance
VI
:
def
__init__
(
self
,
position
,
hamiltonian
,
n_samples
,
mirror_samples
,
initial_sig
=
1
,
comm
=
None
,
nanisinf
=
False
,
names
=
[
'mean'
,
'cov'
]):
"""Collect the operators required for Gaussian full-covariance variational
...
...
@@ -99,7 +114,7 @@ class FullCovariance:
pos
=
MultiField
.
from_dict
({
names
[
0
]:
Flat
(
position
),
names
[
1
]:
makeField
(
generator
.
domain
[
names
[
1
]],
diag_tri
)})
op
=
hamiltonian
(
generator
)
+
entropy
self
.
_names
=
names
self
.
_KL
=
PartialSampledEnergy
.
make
(
pos
,
op
,
[
'latent'
,],
n_samples
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_KL
=
StochasticEnergyAdapter
.
make
(
pos
,
op
,
[
'latent'
,],
n_samples
,
mirror_samples
,
nanisinf
=
nanisinf
,
comm
=
comm
)
self
.
_Flat
=
Flat
@
property
...
...
src/minimization/energy_adapter.py
View file @
28920e82
...
...
@@ -23,7 +23,6 @@ from ..minimization.energy import Energy
from
..utilities
import
myassert
,
allreduce_sum
from
..multi_domain
import
MultiDomain
from
..sugar
import
from_random
from
.kl_energies
import
_SelfAdjointOperatorWrapper
,
_get_lo_hi
class
EnergyAdapter
(
Energy
):
"""Helper class which provides the traditional Nifty Energy interface to
...
...
@@ -91,7 +90,7 @@ class EnergyAdapter(Energy):
class
StochasticEnergyAdapter
(
Energy
):
def
__init__
(
self
,
position
,
op
,
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
def
__init__
(
self
,
position
,
big
op
,
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
...
...
@@ -112,7 +111,7 @@ class StochasticEnergyAdapter(Energy):
self
.
_val
=
np
.
inf
self
.
_grad
=
allreduce_sum
(
g
,
self
.
_comm
)
/
self
.
_n_samples
self
.
_op
=
op
self
.
_op
=
big
op
self
.
_keys
=
keys
@
property
...
...
@@ -136,6 +135,7 @@ class StochasticEnergyAdapter(Energy):
@
property
def
metric
(
self
):
from
.kl_energies
import
_SelfAdjointOperatorWrapper
return
_SelfAdjointOperatorWrapper
(
self
.
position
.
domain
,
self
.
apply_metric
)
...
...
@@ -158,6 +158,7 @@ class StochasticEnergyAdapter(Energy):
samdom
=
MultiDomain
.
make
(
samdom
)
local_ops
=
[]
sseq
=
random
.
spawn_sseq
(
n_samples
)
from
.kl_energies
import
_get_lo_hi
for
i
in
range
(
*
_get_lo_hi
(
comm
,
n_samples
)):
with
random
.
Context
(
sseq
[
i
]):
rnd
=
from_random
(
samdom
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment