Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
5a4f59a7
Commit
5a4f59a7
authored
May 30, 2021
by
Philipp Arras
Browse files
Integrate GeoKL into visualized demo
parent
57d88b83
Changes
2
Hide whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
5a4f59a7
...
@@ -143,7 +143,7 @@ run_curve_fitting:
...
@@ -143,7 +143,7 @@ run_curve_fitting:
paths
:
paths
:
-
'
*.png'
-
'
*.png'
run_visual_
mg
vi
:
run_visual_vi
:
stage
:
demo_runs
stage
:
demo_runs
script
:
script
:
-
python3 demos/
mgvi
_visualized.py
-
python3 demos/
variational_inference
_visualized.py
demos/
mgvi
_visualized.py
→
demos/
variational_inference
_visualized.py
View file @
5a4f59a7
...
@@ -11,21 +11,21 @@
...
@@ -11,21 +11,21 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
# Authors: Reimar Leike, Philipp Arras
, Philipp Frank
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
###############################################################################
#
Metric Gaussian
Variational Inference (
MG
VI)
# Variational Inference (VI)
#
#
# This script demonstrates how MGVI work
s
for an inference problem
with only
# This script demonstrates how MGVI
and GeoVI
work for an inference problem
# two real quantities of interest. This enables us to plot the
posterior
#
with only
two real quantities of interest. This enables us to plot the
# probability density as two-dimensional plot. The
posterior samples gener
ate
d
#
posterior
probability density as two-dimensional plot. The
approxim
ate
#
by MGVI
are contrasted with the maximum-a-posterior (MAP) solution
together
#
posterior samples
are contrasted with the maximum-a-posterior (MAP) solution
# with samples drawn with the Laplace method. This method uses the
local
#
together
with samples drawn with the Laplace method. This method uses the
# curvature at the MAP solution as inverse covariance of a Gaussian
probability
#
local
curvature at the MAP solution as inverse covariance of a Gaussian
# density.
#
probability
density.
###############################################################################
###############################################################################
import
numpy
as
np
import
numpy
as
np
...
@@ -36,8 +36,6 @@ import nifty7 as ift
...
@@ -36,8 +36,6 @@ import nifty7 as ift
def
main
():
def
main
():
use_geo
=
False
name
=
'GEO'
if
use_geo
else
'MGVI'
dom
=
ift
.
UnstructuredDomain
(
1
)
dom
=
ift
.
UnstructuredDomain
(
1
)
scale
=
10
scale
=
10
...
@@ -89,42 +87,52 @@ def main():
...
@@ -89,42 +87,52 @@ def main():
minimizer
=
ift
.
NewtonCG
(
minimizer
=
ift
.
NewtonCG
(
ift
.
GradientNormController
(
iteration_limit
=
2
,
name
=
'Mini'
))
ift
.
GradientNormController
(
iteration_limit
=
2
,
name
=
'Mini'
))
pos
=
ift
.
from_random
(
ham
.
domain
,
'normal'
)
pos
=
pos1
=
ift
.
from_random
(
ham
.
domain
,
'normal'
)
plt
.
figure
(
figsize
=
[
12
,
8
])
fig
,
axs
=
plt
.
subplots
(
2
,
1
,
figsize
=
[
12
,
8
])
for
ii
in
range
(
15
):
for
ii
in
range
(
15
):
if
ii
%
3
==
0
:
if
ii
%
3
==
0
:
if
use_geo
:
# Resample
mini_samp
=
ift
.
NewtonCG
(
mgkl
=
ift
.
MetricGaussianKL
(
pos
,
ham
,
100
,
False
)
ift
.
GradientNormController
(
iteration_limit
=
5
))
mini_samp
=
ift
.
NewtonCG
(
ift
.
GradientNormController
(
iteration_limit
=
5
))
mgkl
=
ift
.
GeoMetricKL
(
pos
,
ham
,
100
,
mini_samp
,
False
)
geokl
=
ift
.
GeoMetricKL
(
pos1
,
ham
,
100
,
mini_samp
,
False
)
else
:
mgkl
=
ift
.
MetricGaussianKL
(
pos
,
ham
,
100
,
False
)
for
axx
in
axs
:
axx
.
clear
()
plt
.
cla
()
im
=
axx
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(
vmin
=
1e-3
,
vmax
=
np
.
max
(
z
)),
plt
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(
vmin
=
1e-3
,
vmax
=
np
.
max
(
z
)),
cmap
=
'gist_earth_r'
,
extent
=
x_limits_scaled
+
y_limits
)
cmap
=
'gist_earth_r'
,
extent
=
x_limits_scaled
+
y_limits
)
if
ii
==
0
:
if
ii
==
0
:
cbar
=
plt
.
colorbar
(
im
,
ax
=
axx
)
cbar
=
plt
.
colorbar
()
cbar
.
ax
.
set_ylabel
(
'pdf'
)
cbar
.
ax
.
set_ylabel
(
'pdf'
)
xs
,
ys
=
[],
[]
for
jj
,
nn
,
kl
,
pp
in
((
0
,
"MGVI"
,
mgkl
,
pos
),
(
1
,
"GeoVI"
,
geokl
,
pos1
)):
for
samp
in
mgkl
.
samples
:
xs
,
ys
=
[],
[]
samp
=
(
samp
+
pos
).
val
for
samp
in
kl
.
samples
:
xs
.
append
(
samp
[
'a'
])
samp
=
(
samp
+
pp
).
val
ys
.
append
(
samp
[
'b'
])
xs
.
append
(
samp
[
'a'
])
plt
.
scatter
(
np
.
array
(
map_xs
)
*
scale
,
np
.
array
(
map_ys
),
ys
.
append
(
samp
[
'b'
])
label
=
'Laplace samples'
)
axs
[
jj
].
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
f
'
{
nn
}
samples'
)
plt
.
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
name
+
' samples'
)
axs
[
jj
].
scatter
(
pp
.
val
[
'a'
]
*
scale
,
pp
.
val
[
'b'
],
label
=
f
'
{
nn
}
latent mean'
)
plt
.
scatter
(
pos
.
val
[
'a'
]
*
scale
,
pos
.
val
[
'b'
],
label
=
name
+
' latent mean'
)
axs
[
jj
].
set_title
(
nn
)
plt
.
scatter
(
MAP
.
position
.
val
[
'a'
]
*
scale
,
MAP
.
position
.
val
[
'b'
],
label
=
'Maximum a posterior solution'
)
for
axx
in
axs
:
plt
.
xlim
(
x_limits_scaled
)
axx
.
scatter
(
np
.
array
(
map_xs
)
*
scale
,
np
.
array
(
map_ys
),
plt
.
ylim
(
y_limits
)
label
=
'Laplace samples'
)
plt
.
legend
()
axx
.
scatter
(
MAP
.
position
.
val
[
'a'
]
*
scale
,
MAP
.
position
.
val
[
'b'
],
label
=
'Maximum a posterior solution'
)
axx
.
set_xlim
(
x_limits_scaled
)
axx
.
set_ylim
(
y_limits
)
axx
.
set_ylabel
(
'y'
)
axx
.
legend
(
loc
=
'lower right'
)
axs
[
0
].
xaxis
.
set_visible
(
False
)
axs
[
1
].
set_xlabel
(
'x'
)
plt
.
tight_layout
()
plt
.
draw
()
plt
.
draw
()
plt
.
pause
(
1.0
)
plt
.
pause
(
1.0
)
mgkl
,
_
=
minimizer
(
mgkl
)
mgkl
,
_
=
minimizer
(
mgkl
)
geokl
,
_
=
minimizer
(
geokl
)
pos
=
mgkl
.
position
pos
=
mgkl
.
position
pos1
=
geokl
.
position
ift
.
logger
.
info
(
'Finished'
)
ift
.
logger
.
info
(
'Finished'
)
# Uncomment the following line in order to leave the plots open
# Uncomment the following line in order to leave the plots open
# plt.show()
# plt.show()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment