Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
d2c0398e
Commit
d2c0398e
authored
Feb 21, 2022
by
Philipp Arras
Browse files
Optimize_kl: add plotting kwargs
parent
64ce0634
Pipeline
#123298
passed with stages
in 18 minutes and 26 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_3.py
View file @
d2c0398e
...
...
@@ -130,7 +130,8 @@ def main():
n_samples
=
lambda
iiter
:
10
if
iiter
<
5
else
20
samples
=
ift
.
optimize_kl
(
likelihood_energy
,
n_iterations
,
n_samples
,
minimizer
,
ic_sampling
,
minimizer_sampling
,
plottable_operators
=
{
"signal"
:
signal
,
"power spectrum"
:
pspec
},
plottable_operators
=
{
"signal"
:
(
signal
,
dict
(
vmin
=
0
,
vmax
=
1
)),
"power spectrum"
:
pspec
},
ground_truth_position
=
mock_position
,
output_directory
=
"getting_started_3_results"
,
overwrite
=
True
,
comm
=
comm
,
inspect_callback
=
callback
)
...
...
src/minimization/optimize_kl.py
View file @
d2c0398e
...
...
@@ -12,6 +12,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -21,6 +22,8 @@ from os import makedirs
from
os.path
import
isdir
,
isfile
,
join
from
warnings
import
warn
from
matplotlib.colors
import
LogNorm
from
..domain_tuple
import
DomainTuple
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
...
...
@@ -123,7 +126,10 @@ def optimize_kl(likelihood_energy,
Default is to draw samples for the complete domain.
plottable_operators : dict
Dictionary of operators that are plotted during the minimization. The
key contains a string that serves as identifier.
key contains a string that serves as identifier. The value of the
dictionary can either be an operator or a tuple of an operator and a
dictionary that contains kwargs for the plotting that are passed into
the NIFTy plotting routine.
output_directory : str or None
Directory in which all output files are saved. If None, no output is
stored. Default: "nifty_optimize_kl_output".
...
...
@@ -277,6 +283,10 @@ def optimize_kl(likelihood_energy,
for
k1
,
op
in
plottable_operators
.
items
():
if
mf_dom
:
if
isinstance
(
op
,
tuple
)
and
len
(
op
)
==
2
:
if
not
isinstance
(
op
[
1
],
dict
):
raise
TypeError
op
=
op
[
0
]
for
k2
,
vv
in
op
.
domain
.
items
():
if
k2
in
dom
.
keys
()
and
dom
[
k2
]
!=
vv
:
raise
ValueError
(
f
"The domain of plottable operator '
{
k1
}
' "
...
...
@@ -428,14 +438,19 @@ def _plot_operators(output_directory, index, plottable_operators, sample_list,
raise
TypeError
for
name
,
op
in
plottable_operators
.
items
():
plotting_kwargs
=
{}
if
isinstance
(
op
,
tuple
)
and
len
(
op
)
==
2
:
op
,
plotting_kwargs
=
op
if
not
isinstance
(
plotting_kwargs
,
dict
):
raise
TypeError
if
not
_is_subdomain
(
op
.
domain
,
sample_list
.
domain
):
continue
gt
=
_op_force_or_none
(
op
,
ground_truth
)
fname
=
_file_name
(
output_directory
,
name
,
index
,
"samples_"
)
_plot_samples
(
fname
,
sample_list
.
iterator
(
op
),
gt
,
comm
)
_plot_samples
(
fname
,
sample_list
.
iterator
(
op
),
gt
,
comm
,
plotting_kwargs
)
if
sample_list
.
n_samples
>
1
:
fname
=
_file_name
(
output_directory
,
name
,
index
,
"stats_"
)
_plot_stats
(
fname
,
*
sample_list
.
sample_stat
(
op
),
gt
,
comm
)
_plot_stats
(
fname
,
*
sample_list
.
sample_stat
(
op
),
gt
,
comm
,
plotting_kwargs
)
op_direc
=
join
(
output_directory
,
name
)
if
sample_list
.
n_samples
>
1
:
...
...
@@ -463,7 +478,7 @@ def _plot_operators(output_directory, index, plottable_operators, sample_list,
pass
def
_plot_samples
(
file_name
,
samples
,
ground_truth
,
comm
):
def
_plot_samples
(
file_name
,
samples
,
ground_truth
,
comm
,
plotting_kwargs
):
samples
=
list
(
samples
)
if
_MPI_master
(
comm
):
...
...
@@ -481,12 +496,13 @@ def _plot_samples(file_name, samples, ground_truth, comm):
if
plottable2D
(
samples
[
0
][
kk
]):
if
ground_truth
is
not
None
:
p
.
add
(
ground_truth
[
kk
],
title
=
_append_key
(
"Ground truth"
,
kk
))
p
.
add
(
ground_truth
[
kk
],
title
=
_append_key
(
"Ground truth"
,
kk
),
**
plotting_kwargs
)
p
.
add
(
None
)
for
ii
,
ss
in
enumerate
(
single_samples
):
if
(
ground_truth
is
None
and
ii
==
16
)
or
(
ground_truth
is
not
None
and
ii
==
14
):
break
p
.
add
(
ss
,
title
=
_append_key
(
f
"Samples
{
ii
}
"
,
kk
))
p
.
add
(
ss
,
title
=
_append_key
(
f
"Samples
{
ii
}
"
,
kk
)
,
**
plotting_kwargs
)
else
:
n
=
len
(
samples
)
alpha
=
n
*
[
0.5
]
...
...
@@ -497,7 +513,8 @@ def _plot_samples(file_name, samples, ground_truth, comm):
alpha
=
[
1.
]
+
alpha
color
=
[
"green"
]
+
color
label
=
[
"Ground truth"
,
"Samples"
]
+
(
n
-
1
)
*
[
None
]
p
.
add
(
single_samples
,
color
=
color
,
alpha
=
alpha
,
label
=
label
,
title
=
_append_key
(
"Samples"
,
kk
))
p
.
add
(
single_samples
,
color
=
color
,
alpha
=
alpha
,
label
=
label
,
title
=
_append_key
(
"Samples"
,
kk
),
**
plotting_kwargs
)
p
.
output
(
name
=
file_name
)
...
...
@@ -507,12 +524,12 @@ def _append_key(s, key):
return
f
"
{
s
}
(
{
key
}
)"
def
_plot_stats
(
file_name
,
mean
,
var
,
ground_truth
,
comm
):
def
_plot_stats
(
file_name
,
mean
,
var
,
ground_truth
,
comm
,
plotting_kwargs
):
p
=
Plot
()
if
ground_truth
is
not
None
:
p
.
add
(
ground_truth
,
title
=
"Ground truth"
)
p
.
add
(
mean
,
title
=
"Mean"
)
p
.
add
(
var
.
sqrt
(),
vmin
=
0
,
title
=
"Standard deviation"
)
p
.
add
(
ground_truth
,
title
=
"Ground truth"
,
**
plotting_kwargs
)
p
.
add
(
mean
,
title
=
"Mean"
,
**
plotting_kwargs
)
p
.
add
(
var
.
sqrt
(),
title
=
"Standard deviation"
,
norm
=
LogNorm
()
)
if
_MPI_master
(
comm
):
p
.
output
(
name
=
file_name
,
ny
=
2
if
ground_truth
is
None
else
3
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment