Commit a03a8fbd authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'plottest' into 'NIFTy_5'

Allow multifigure plots

See merge request ift/nifty-dev!57
parents 19a7c1ed 6d38f478
......@@ -88,6 +88,7 @@ if __name__ == '__main__':
reconstruction =
ift.plot(reconstruction, title='reconstruction', name='reconstruction.png')
ift.plot(GR.adjoint_times(data), title='data', name='data.png')
ift.plot(, title='truth', name='truth.png')
ift.plot(reconstruction, title='reconstruction')
ift.plot(GR.adjoint_times(data), title='data')
ift.plot(, title='truth')
ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="bernoulli.png")
......@@ -46,7 +46,7 @@ if __name__ == '__main__':
# FIXME description of the tutorial
# Choose problem geometry and masking
mode = 0
mode = 1
if mode == 0:
# One dimensional regular grid
position_space = ift.RGSpace([1024])
......@@ -106,11 +106,14 @@ if __name__ == '__main__':
if rg and len(position_space.shape) == 1:
ift.plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)],
label=['Mock signal', 'Data', 'Reconstruction'],
alpha=[1, .3, 1],
alpha=[1, .3, 1])
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)))
ift.plot_finish(nx=2, ny=1, xsize=10, ysize=4,
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal', name='mock_signal.png')
ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)),
title='Data', name='data.png')
ift.plot(HT(m), title='Reconstruction', name='reconstruction.png')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), name='residuals.png')
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), title='Data')
ift.plot(HT(m), title='Reconstruction')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)))
ift.plot_finish(nx=2, ny=2, xsize=10, ysize=10,
......@@ -83,13 +83,14 @@ if __name__ == '__main__':
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
ift.plot(, name='truth.png')
ift.plot(R.adjoint_times(data), name='data.png')
ift.plot([], name='power.png')
ift.plot(, title='ground truth')
ift.plot(R.adjoint_times(data), title='data')
ift.plot([], title='power')
ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png")
# number of samples used to estimate the KL
N_samples = 20
for i in range(5):
for i in range(2):
H =
samples = [H.metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
......@@ -99,17 +100,19 @@ if __name__ == '__main__':
KL, convergence = minimizer(KL)
position = KL.position
ift.plot(, name='reconstruction.png')
ift.plot(, title="reconstruction")
ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png")
sc = ift.StatCalculator()
for sample in samples:
ift.plot(sc.mean, name='avrg.png')
ift.plot(ift.sqrt(sc.var), name='std.png')
ift.plot(sc.mean, title="mean")
ift.plot(ift.sqrt(sc.var), title="std deviation")
powers = [ for s in samples]
ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="results.png")
import nifty5 as ift
import numpy as np
def plot_test():
rg_space1 = ift.makeDomain(ift.RGSpace((100,)))
rg_space2 = ift.makeDomain(ift.RGSpace((80, 80)))
hp_space = ift.makeDomain(ift.HPSpace(64))
gl_space = ift.makeDomain(ift.GLSpace(128))
fft = ift.FFTOperator(rg_space2)
field_rg1_1 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg1_2 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg2 = ift.Field.from_global_data(
rg_space2, np.random.randn(80 ** 2).reshape((80, 80)))
field_hp = ift.Field.from_global_data(hp_space, np.random.randn(12*64**2))
field_gl = ift.Field.from_global_data(gl_space, np.random.randn(32640))
field_ps = ift.power_analyze(fft.times(field_rg2))
## Start various plotting tests
ift.plot(field_rg1_1, title='Single plot')
ift.plot(field_rg2, title='2d rg')
ift.plot([field_rg1_1, field_rg1_2], title='list 1d rg', label=['1', '2'])
ift.plot(field_rg1_2, title='1d rg, xmin, ymin', xmin=0.5, ymin=0.,
xlabel='xmin=0.5', ylabel='ymin=0')
ift.plot_finish(title='Three plots')
ift.plot(field_hp, title='HP planck-color', colormap='Planck-like')
ift.plot(field_rg1_2, title='1d rg')
ift.plot(field_gl, title='GL')
ift.plot(field_rg2, title='2d rg')
ift.plot_finish(nx=2, ny=3, title='Five plots')
if __name__ == '__main__':
......@@ -75,7 +75,7 @@ from .minimization.line_energy import LineEnergy
from .minimization.energy_sum import EnergySum
from .sugar import *
from .plotting.plot import plot
from .plotting.plot import plot, plot_finish
from .library.amplitude_model import make_amplitude_model
from .library.gaussian_energy import GaussianEnergy
......@@ -23,7 +23,12 @@ import os
import numpy as np
from ..compat import *
from .. import Field, GLSpace, HPSpace, PowerSpace, RGSpace, dobj
from ..field import Field
from import GLSpace
from import HPSpace
from import PowerSpace
from import RGSpace
from .. import dobj
# relevant properties:
# - x/y size
......@@ -81,16 +86,6 @@ def _makeplot(name):
elif extension == ".png":
# elif extension==".html":
# import mpld3
# mpld3.save_html(plt.gcf(),fileobj=name,no_extras=True)
# import plotly.offline as py
# import as tls
# plotly_fig = tls.mpl_to_plotly(plt.gcf())
# py.plot(plotly_fig,filename=name)
# py.plot_mpl(plt.gcf(),filename=name)
# import bokeh
# bokeh.mpl.to_bokeh(plt.gcf())
raise ValueError("file format not understood")
......@@ -169,7 +164,7 @@ def _register_cmaps():
plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap))
def plot(f, **kwargs):
def _plot(f, ax, **kwargs):
import matplotlib.pyplot as plt
if isinstance(f, Field):
......@@ -209,12 +204,6 @@ def plot(f, **kwargs):
alpha = [alpha]
dom = dom[0]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
xsize = kwargs.pop("xsize", 6)
ysize = kwargs.pop("ysize", 6)
fig.set_size_inches(xsize, ysize)
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
......@@ -231,7 +220,6 @@ def plot(f, **kwargs):
if label != ([None]*len(f)):
elif len(dom.shape) == 2:
f = f[0]
......@@ -251,7 +239,6 @@ def plot(f, **kwargs):
# plt.colorbar(im,cax=cax)
elif isinstance(dom, PowerSpace):
......@@ -265,7 +252,6 @@ def plot(f, **kwargs):
if label != ([None]*len(f)):
elif isinstance(dom, HPSpace):
f = f[0]
......@@ -282,7 +268,6 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
elif isinstance(dom, GLSpace):
f = f[0]
......@@ -300,7 +285,85 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
raise ValueError("Field type not(yet) supported")
_plots = []
_kwargs = []
def plot(f, **kwargs):
"""Add a figure to the current list of plots.
After doing one or more calls `plot()`, one also needs to call
`plot_finish()` to output the result.
f: Field, or list of Field objects
If `f` is a single Field, it must live over a single `RGSpace`,
`PowerSpace`, `HPSpace`, `GLSPace`.
If it is a list, all list members must be Fields living over the same
one-dimensional `RGSpace` or `PowerSpace`.
title: string
title of the plot
xlabel: string
label for the x axis
ylabel: string
label for the y axis
[xyz]min, [xyz]max: float
limits for the values to plot
colormap: string
color map to use for the plot (if it is a 2D plot)
linewidth: float or list of floats
line width
label: string of list of strings
annotation string
alpha: float or list of floats
transparency value
def plot_finish(**kwargs):
"""Plot the accumulated list of figures.
title: string
title of the full plot
nx, ny: integer (default: square root of the numer of plots, rounded up)
number of subplots to use in x- and y-direction
xsize, ysize: float (default: 6)
dimensions of the full plot in inches
name: string (default: "")
if left empty, the plot will be shown on the screen,
otherwise it will be written to a file with the given name.
Supported extensions: .png and .pdf
global _plots, _kwargs
import matplotlib.pyplot as plt
nplot = len(_plots)
fig = plt.figure()
if "title" in kwargs:
nx = kwargs.pop("nx", int(np.ceil(np.sqrt(nplot))))
ny = kwargs.pop("ny", int(np.ceil(np.sqrt(nplot))))
if nx*ny < nplot:
raise ValueError(
'Figure dimensions not sufficient for number of plots')
xsize = kwargs.pop("xsize", 6)
ysize = kwargs.pop("ysize", 6)
fig.set_size_inches(xsize, ysize)
for i in range(nplot):
ax = fig.add_subplot(ny, nx, i+1)
_plot(_plots[i], ax, **_kwargs[i])
_makeplot(kwargs.pop("name", None))
_plots = []
_kwargs = []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment