NIFTy commit ada00fdb, authored Jun 11, 2021 by Martin Reinecke

Merge branch 'more_samplers' into 'NIFTy_7'

Parametric MGVI

See merge request ift/nifty!604

Parents: c76f44f3, 2cbf787a
Changes: 19 files
.gitlab-ci.yml
...
...
@@ -148,6 +148,11 @@ run_visual_vi:
   script:
     - python3 demos/variational_inference_visualized.py
 
+run_meanfield:
+  stage: demo_runs
+  script:
+    - python3 demos/parametric_variational_inference.py
+
 run_nonlinearity_guide:
   stage: demo_runs
   script:
...
...
demos/parametric_variational_inference.py (new file, 0 → 100644)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# Meanfield and fullcovariance variational inference
#
# The signal is a 1-D log-normally distributed field.
# The data follow a Poisson likelihood.
# The posterior distribution is approximated with a diagonal as well as a
# full-covariance Gaussian distribution. This is achieved by minimizing
# a stochastic estimate of the KL divergence.
#
# Note that the fullcovariance approximation scales quadratically with the
# number of parameters.
###############################################################################
import numpy as np
import nifty7 as ift
from matplotlib import pyplot as plt

ift.random.push_sseq_from_seed(27)

if __name__ == "__main__":
    # Space and model setup
    position_space = ift.RGSpace([100])
    harmonic_space = position_space.get_default_codomain()
    HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
    p_space = ift.PowerSpace(harmonic_space)
    pd = ift.PowerDistributor(harmonic_space, p_space)
    a = ift.PS_field(p_space, lambda k: 1.0/(1.0 + k**2))
    A = pd(a)
    sky = 10*ift.exp(HT(ift.makeOp(A))).ducktape("xi")

    R = ift.GeometryRemover(position_space)
    mask = np.zeros(position_space.shape)
    mask[mask.shape[0]//3:2*mask.shape[0]//3] = 1
    mask = ift.Field.from_raw(position_space, mask)
    R = ift.MaskOperator(mask)
    d_space = R.target[0]
    lamb = R(sky)

    # Generate simulated signal and data and build log-likelihood
    mock_position = ift.from_random(sky.domain, "normal")
    data = ift.random.current_rng().poisson(lamb(mock_position).val)
    data = ift.makeField(d_space, data)
    loglikelihood = ift.PoissonianEnergy(data) @ lamb
    H = ift.StandardHamiltonian(loglikelihood)

    # Settings for minimization
    IC = ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200,
                                                name='advi')
    minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
    minimizer_mf = ift.ADVIOptimizer(IC)

    # Initial positions
    position_fc = ift.from_random(H.domain)*0.1
    position_mf = ift.from_random(H.domain)*0.1

    # Setup of the variational models
    fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
    mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)

    niter = 10
    for ii in range(niter):
        # Plotting
        plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
        plt.plot(sky(mf.mean).val, "r-", label="Mean field")
        for _ in range(5):
            plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
            plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
        plt.plot(R.adjoint(data).val, "kx")
        plt.plot(sky(mock_position).val, "k-", label="Ground truth")
        plt.legend()
        plt.ylim(0.1, data.val.max() + 10)
        fname = f"meanfield_{ii:03d}.png"
        plt.savefig(fname)
        print(f"Saved results as '{fname}' ({ii}/{niter - 1}).")
        plt.close()
        # /Plotting

        # Run minimization
        fc.minimize(minimizer_fc)
        mf.minimize(minimizer_mf)
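Aside: the quantity both ADVI runs above minimize can be written out in a few lines of plain NumPy. The sketch below is illustrative only (a toy 1-D standard-normal target; neg_log_p and kl_estimate are made-up names, not NIFTy API): a sample of the approximation is produced deterministically from white noise via the reparametrization trick, and the KL estimate is the average energy of those samples minus the Gaussian entropy of the approximation.

import numpy as np

rng = np.random.default_rng(0)

def neg_log_p(x):
    # Toy target: negative log-density of a standard normal, up to a constant.
    return 0.5*x**2

def kl_estimate(mean, std, n_draws=1000):
    # Reparametrization trick: a sample of N(mean, std^2) is mean + |std|*xi
    # with xi ~ N(0, 1), so the variational parameters stay differentiable.
    xi = rng.standard_normal(n_draws)
    samples = mean + np.abs(std)*xi
    entropy = 0.5*np.log(2*np.pi*np.e*std**2)
    # KL(q || p) up to a constant: mean energy minus the entropy of q.
    return neg_log_p(samples).mean() - entropy

print(kl_estimate(0.5, 2.0))  # decreases as (mean, std) approaches (0, 1)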
demos/variational_inference_visualized.py
...
...
@@ -19,13 +19,9 @@
 ###############################################################################
 # Variational Inference (VI)
 #
-# This script demonstrates how MGVI and GeoVI work for an inference problem
-# with only two real quantities of interest. This enables us to plot the
-# posterior probability density as two-dimensional plot. The approximate
-# posterior samples are contrasted with the maximum-a-posterior (MAP) solution
-# together with samples drawn with the Laplace method. This method uses the
-# local curvature at the MAP solution as inverse covariance of a Gaussian
-# probability density.
+# This script demonstrates how MGVI, GeoVI, MeanfieldVI and FullCovarianceVI
+# work for an inference problem with only two real quantities of interest. This
+# enables us to plot the posterior probability density as two-dimensional plot.
 ###############################################################################
 import numpy as np
...
...
@@ -74,65 +70,93 @@ def main():
     plt.pause(2.0)
     plt.close()
 
-    pos = ift.from_random(ham.domain, 'normal')
-    MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
-    minimizer = ift.NewtonCG(
-        ift.GradientNormController(iteration_limit=20, name='Mini'))
-    MAP, _ = minimizer(MAP)
-    map_xs, map_ys = [], []
-    for ii in range(10):
-        samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
-        map_xs.append(samp['a'])
-        map_ys.append(samp['b'])
+    mapx = xx[z == np.max(z)]
+    mapy = yy[z == np.max(z)]
+    meanx = (xx*z).sum()/z.sum()
+    meany = (yy*z).sum()/z.sum()
 
+    n_samples = 100
     minimizer = ift.NewtonCG(
-        ift.GradientNormController(iteration_limit=3, name='Mini'))
-    pos = pos1 = ift.from_random(ham.domain, 'normal')
-    fig, axs = plt.subplots(2, 1, figsize=[12, 8])
-    for ii in range(15):
-        if ii % 3 == 0:
-            # Resample
-            mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
-            mini_samp = ift.NewtonCG(
-                ift.GradientNormController(iteration_limit=5))
-            geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
-        for axx in axs:
+        ift.GradientNormController(iteration_limit=2, name='Mini'))
+    IC = ift.StochasticAbsDeltaEnergyController(0.5, iteration_limit=20,
+                                                name='advi')
+    stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta=0.3)
+    stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta=0.3)
+    posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
+    fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
+    mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
+    fig, axs = plt.subplots(2, 2, figsize=[12, 8])
+    axs = axs.flatten()
+
+    def update_plot(runs):
+        for axx, (nn, kl, pp, sam) in zip(axs, runs):
             axx.clear()
-            im = axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
-                            norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
+            im = axx.imshow(z.T, origin='lower',
+                            norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
+                            cmap='gist_earth_r',
                             extent=x_limits_scaled + y_limits)
             if ii == 0:
                 cbar = plt.colorbar(im, ax=axx)
                 cbar.ax.set_ylabel('pdf')
-        for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos),
-                               (1, "GeoVI", geokl, pos1)):
             xs, ys = [], []
-            for samp in kl.samples:
-                samp = (samp + pp).val
-                xs.append(samp['a'])
-                ys.append(samp['b'])
-            axs[jj].scatter(np.array(xs)*scale, np.array(ys),
-                            label=f'{nn} samples')
-            axs[jj].scatter(pp.val['a']*scale, pp.val['b'],
-                            label=f'{nn} latent mean')
-            axs[jj].set_title(nn)
-        for axx in axs:
-            axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
-                        label='Laplace samples')
-            axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
-                        label='Maximum a posterior solution')
+            if sam:
+                samples = (samp + pp for samp in kl.samples)
+            else:
+                samples = (kl.draw_sample() for _ in range(n_samples))
+            mx, my = 0., 0.
+            for samp in samples:
+                a = samp.val['a']
+                xs.append(a)
+                mx += a
+                b = samp.val['b']
+                ys.append(b)
+                my += b
+            mx /= n_samples
+            my /= n_samples
+            axx.scatter(np.array(xs)*scale, np.array(ys),
+                        label=f'{nn} samples')
+            axx.scatter(mx*scale, my, label=f'{nn} mean')
+            axx.scatter(mapx*scale, mapy, label='MAP')
+            axx.scatter(meanx*scale, meany, label='Posterior mean')
+            axx.set_title(nn)
             axx.set_xlim(x_limits_scaled)
             axx.set_ylim(y_limits)
             axx.set_ylabel('y')
             axx.legend(loc='lower right')
         axs[0].xaxis.set_visible(False)
-        axs[1].set_xlabel('x')
+        axs[1].xaxis.set_visible(False)
+        axs[1].yaxis.set_visible(False)
+        axs[2].set_xlabel('x')
+        axs[2].set_ylabel('y')
+        axs[3].yaxis.set_visible(False)
+        axs[3].set_xlabel('x')
         plt.tight_layout()
         plt.draw()
-        plt.pause(1.0)
+        plt.pause(2.0)
+
+    for ii in range(20):
+        if ii % 2 == 0:
+            # Resample GeoVI and MGVI
+            mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
+            mini_samp = ift.NewtonCG(
+                ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
+            geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
+        runs = (("MGVI", mgkl, posmg, True), ("GeoVI", geokl, posgeo, True),
+                ("MeanfieldVI", mf, posmf, False),
+                ("FullCovarianceVI", fc, posfc, False))
+        update_plot(runs)
         mgkl, _ = minimizer(mgkl)
         geokl, _ = minimizer(geokl)
-        pos = mgkl.position
-        pos1 = geokl.position
+        mf.minimize(stochastic_minimizer_mf)
+        fc.minimize(stochastic_minimizer_fc)
+        posmg = mgkl.position
+        posgeo = geokl.position
+        posmf = mf.mean
+        posfc = fc.mean
+    runs = (("MGVI", mgkl, posmg, True), ("GeoVI", geokl, posgeo, True),
+            ("MeanfieldVI", mf, posmf, False),
+            ("FullCovarianceVI", fc, posfc, False))
+    update_plot(runs)
     ift.logger.info('Finished')
     # Uncomment the following line in order to leave the plots open
     # plt.show()
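The fourth entry of each `runs` tuple (the `sam` flag) encodes how posterior samples are obtained: MGVI/GeoVI KL objects carry a fixed set of residual samples around their latent mean, while the parametric models draw fresh samples from their generator. A hypothetical helper making that dispatch explicit (not part of the demo):

def posterior_samples(kl, pp, sam, n_samples):
    if sam:
        # MGVI/GeoVI: stored residual samples, recentered on the latent mean.
        return (samp + pp for samp in kl.samples)
    # MeanfieldVI/FullCovarianceVI: sample the fitted Gaussian directly.
    return (kl.draw_sample() for _ in range(n_samples))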
...
...
src/__init__.py
...
...
@@ -53,6 +53,7 @@ from .operators.energy_operators import (
     Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
 from .operators.convolution_operators import FuncConvolutionOperator
 from .operators.normal_operators import NormalTransform, LognormalTransform
+from .operators.multifield2vector import Multifield2Vector
 from .probing import probe_with_posterior_samples, probe_diagonal, \
     StatCalculator, approximation2endo
...
...
@@ -60,17 +61,18 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
 from .minimization.line_search import LineSearch
 from .minimization.iteration_controllers import (
     IterationController, GradientNormController, DeltaEnergyController,
-    GradInfNormController, AbsDeltaEnergyController)
+    GradInfNormController, AbsDeltaEnergyController,
+    StochasticAbsDeltaEnergyController)
 from .minimization.minimizer import Minimizer
 from .minimization.conjugate_gradient import ConjugateGradient
 from .minimization.nonlinear_cg import NonlinearCG
 from .minimization.descent_minimizers import (
     DescentMinimizer, SteepestDescent, VL_BFGS, L_BFGS, RelaxedNewton,
     NewtonCG)
+from .minimization.stochastic_minimizer import ADVIOptimizer
 from .minimization.scipy_minimizer import L_BFGS_B
 from .minimization.energy import Energy
 from .minimization.quadratic_energy import QuadraticEnergy
-from .minimization.energy_adapter import EnergyAdapter
+from .minimization.energy_adapter import EnergyAdapter, StochasticEnergyAdapter
 from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
 from .sugar import *
...
...
@@ -90,6 +92,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
 from .library.nft import Gridder, FinuFFT
 from .library.correlated_fields import CorrelatedFieldMaker
 from .library.correlated_fields_simple import SimpleCorrelatedField
+from .library.variational_models import MeanFieldVI, FullCovarianceVI
 from . import extra
...
...
src/field.py
...
...
@@ -405,6 +405,11 @@ class Field(Operator):
         return Field(DomainTuple.make(return_domain), data)
 
+    def scale(self, factor):
+        if factor == 1:
+            return self
+        return factor*self
+
     def sum(self, spaces=None):
         """Sums up over the sub-domains given by `spaces`.
...
...
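The new Field.scale method above is a small convenience: it multiplies the field by a scalar, short-circuiting factor == 1 to return the field itself rather than allocating a copy. A minimal usage sketch (domain and values are arbitrary):

import nifty7 as ift

f = ift.full(ift.RGSpace(4), 3.0)
g = f.scale(2)   # same as 2*f
h = f.scale(1)   # no-op: returns f itself, no copy is made
assert h is f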
src/library/variational_models.py (new file, 0 → 100644)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert


class MeanFieldVI:
    """Collect the operators required for Gaussian meanfield variational
    inference.

    Gaussian meanfield variational inference approximates some target
    distribution with a Gaussian distribution with a diagonal covariance
    matrix. The parameters of the approximation, in this case the mean and
    standard deviation, are obtained by minimizing a stochastic estimate of the
    Kullback-Leibler divergence between the target and the approximation. In
    order to obtain gradients w.r.t the parameters, the reparametrization trick
    is employed, which separates the stochastic part of the approximation from
    a deterministic function, the generator. Samples from the approximation are
    drawn by processing samples from a standard Gaussian through this
    generator.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : Energy
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negative of the drawn samples are also used, as they are
        equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme sample
        variation is counterbalanced. Since it improves stability in many
        cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive Field or positive float
        The initial estimate of the standard deviation.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible across
        this communicator. If `mirror_samples` is set, then a sample and its
        mirror image will always reside on the same task.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on these
        occasions but rather the minimizer is told that the position it has
        tried is not sensible.
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        Flat = Multifield2Vector(position.domain)
        self._std = FieldAdapter(Flat.target, 'std').absolute()
        latent = FieldAdapter(Flat.target, 'latent')
        self._mean = FieldAdapter(Flat.target, 'mean')
        self._generator = Flat.adjoint(self._mean + self._std*latent)
        self._entropy = GaussianEntropy(self._std.target) @ self._std
        self._mean = Flat.adjoint @ self._mean
        self._std = Flat.adjoint @ self._std
        pos = {'mean': Flat(position)}
        if is_fieldlike(initial_sig):
            pos['std'] = Flat(initial_sig)
        else:
            pos['std'] = full(Flat.target, initial_sig)
        pos = MultiField.from_dict(pos)
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
                                                mirror_samples,
                                                nanisinf=nanisinf, comm=comm)
        self._samdom = latent.domain

    @property
    def mean(self):
        return self._mean.force(self._KL.position)

    @property
    def std(self):
        return self._std.force(self._KL.position)

    @property
    def entropy(self):
        return self._entropy.force(self._KL.position)

    @property
    def KL(self):
        return self._KL

    def draw_sample(self):
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        self._KL, _ = minimizer(self._KL)
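Condensed from the new demo demos/parametric_variational_inference.py above (H and position_mf as defined there), the class is driven entirely through minimize and the mean/draw_sample accessors; a hedged usage sketch:

import nifty7 as ift

# H is a StandardHamiltonian, position_mf an initial latent position,
# both built as in the demo above.
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
mf.minimize(ift.ADVIOptimizer(
    ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200)))
posterior_mean = mf.mean        # current approximate posterior mean
one_sample = mf.draw_sample()   # draw from the fitted diagonal Gaussian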
class FullCovarianceVI:
    """Collect the operators required for Gaussian full-covariance variational
    inference.

    Gaussian full-covariance variational inference approximates some target
    distribution with a Gaussian distribution with a full covariance
    matrix. The parameters of the approximation, in this case the mean and a
    lower triangular matrix corresponding to a Cholesky decomposition of the
    covariance, are obtained by minimizing a stochastic estimate of the
    Kullback-Leibler divergence between the target and the approximation. In
    order to obtain gradients w.r.t the parameters, the reparametrization trick
    is employed, which separates the stochastic part of the approximation from
    a deterministic function, the generator. Samples from the approximation are
    drawn by processing samples from a standard Gaussian through this
    generator.

    Note that the size of the covariance scales quadratically with the number
    of model parameters.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : Energy
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negative of the drawn samples are also used, as they are
        equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme sample
        variation is counterbalanced. Since it improves stability in many
        cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive float
        The initial estimate for the standard deviation. Initially no
        correlation between the parameters is assumed.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible across
        this communicator. If `mirror_samples` is set, then a sample and its
        mirror image will always reside on the same task.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on these
        occasions but rather the minimizer is told that the position it has
        tried is not sensible.
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        Flat = Multifield2Vector(position.domain)
        flat_domain = Flat.target[0]
        mat_space = DomainTuple.make((flat_domain, flat_domain))
        lat = FieldAdapter(Flat.target, 'latent')
        LT = LowerTriangularInserter(mat_space)
        tri = FieldAdapter(LT.domain, 'cov')
        mean = FieldAdapter(flat_domain, 'mean')
        cov = LT @ tri
        matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
        MatMult = MultiLinearEinsum(matmul_setup.target, 'ij,j->i',
                                    key_order=('co', 'latent'))
        self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
        diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
        self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
        diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
        pos = MultiField.from_dict(
            {'mean': Flat(position),
             'cov': LT.adjoint(makeField(mat_space, diag_tri))})
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
                                                mirror_samples,
                                                nanisinf=nanisinf, comm=comm)
        self._mean = Flat.adjoint @ mean
        self._samdom = lat.domain

    @property
    def mean(self):
        return self._mean.force(self._KL.position)

    @property
    def entropy(self):
        return self._entropy.force(self._KL.position)

    @property
    def KL(self):
        return self._KL

    def draw_sample(self):
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        self._KL, _ = minimizer(self._KL)
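The generator built in __init__ is the full-covariance analogue of the mean-field one. A plain-NumPy picture of the same construction (illustrative only, not NIFTy code): with a lower-triangular factor L, mean + L @ xi has covariance L @ L.T, so the Cholesky factor parametrizes the full covariance while all randomness lives in xi.

import numpy as np

rng = np.random.default_rng(0)

N = 4
mean = np.zeros(N)
L = np.tril(0.1*rng.standard_normal((N, N)))  # lower-triangular factor
np.fill_diagonal(L, 0.5)                      # keep it well-conditioned

xi = rng.standard_normal(N)   # xi ~ N(0, I) carries the randomness
sample = mean + L @ xi        # has covariance L @ L.T
cov = L @ L.T                 # N x N entries: quadratic growth in N

This also makes the quadratic scaling noted in the docstring concrete: the factor has N(N+1)/2 free entries.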
class GaussianEntropy(EnergyOperator):
    """Entropy of a Gaussian distribution given the diagonal of a triangular
    decomposition of the covariance.

    As metric a `SandwichOperator` of the Jacobian is used. This is not a
    proper Fisher metric but may be useful for second order minimization.

    Parameters
    ----------
    domain: Domain, DomainTuple, list of Domain
        The domain of the diagonal.
    """
    def __init__(self, domain):
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Field):
            if not np.issubdtype(x.dtype, np.floating):
                raise NotImplementedError("only real fields are allowed")
        if isinstance(x, MultiField):
            for key in x.keys():
                if not np.issubdtype(x[key].dtype, np.floating):
                    raise NotImplementedError("only real fields are allowed")
        res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
        if not isinstance(x, Linearization):
            return res
        if not x.want_metric:
            return res
        return res.add_metric(SandwichOperator.make(res.jac))
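For reference, the value computed by apply is the negative entropy of a diagonal Gaussian whose standard deviations are |x|, which is why adding it to the Hamiltonian yields a KL-type energy. A quick NumPy check of the closed form (illustrative only):

import numpy as np

x = np.array([0.5, 1.0, 2.0])
energy = -0.5*np.sum(np.log(2*np.pi*np.e*x**2))   # what apply() computes
entropy = 0.5*np.sum(np.log(2*np.pi*np.e*x**2))   # Gaussian entropy
assert np.isclose(energy, -entropy)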
class LowerTriangularInserter(LinearOperator):
    """Insert the entries of a lower triangular matrix into a matrix.

    Parameters
    ----------
    target: Domain, DomainTuple, list of Domain
        A two-dimensional domain with NxN entries.
    """
    def __init__(self, target):
        myassert(len(target.shape) == 2)
        myassert(target.shape[0] == target.shape[1])
        self._target = makeDomain(target)
        ndof = (target.shape[0]*(target.shape[0] + 1))//2
        self._domain = makeDomain(UnstructuredDomain(ndof))
        self._indices = np.tril_indices(target.shape[0])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
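The body of apply is cut off in this view. In plain NumPy, the documented behavior amounts to scattering N(N+1)/2 degrees of freedom into the lower triangle of an N x N matrix and gathering them back out (a sketch under that assumption, not the class's code):

import numpy as np

N = 3
indices = np.tril_indices(N)             # as stored in self._indices
flat = np.arange(1., N*(N + 1)//2 + 1)   # 6 degrees of freedom for N = 3

mat = np.zeros((N, N))
mat[indices] = flat          # TIMES: insert into the lower triangle
recovered = mat[indices]     # ADJOINT_TIMES: read the entries back out
assert np.array_equal(recovered, flat)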