From bedbdb646bb5d1115489af293ea0ef6b42400b1e Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Fri, 13 Jul 2018 19:45:34 +0200 Subject: [PATCH] adjust demos --- demos/bernoulli_demo.py | 7 ++++--- demos/getting_started_1.py | 8 +++++--- demos/getting_started_3.py | 21 ++++++++++++--------- nifty5/plotting/plot.py | 16 ++++------------ 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index 43ba4b56..8335e9af 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -88,6 +88,7 @@ if __name__ == '__main__': reconstruction = sky.at(H.position).value - ift.plot(reconstruction, title='reconstruction', name='reconstruction.png') - ift.plot(GR.adjoint_times(data), title='data', name='data.png') - ift.plot(sky.at(mock_position).value, title='truth', name='truth.png') + ift.plot(reconstruction, title='reconstruction') + ift.plot(GR.adjoint_times(data), title='data') + ift.plot(sky.at(mock_position).value, title='truth') + ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="bernoulli.png") diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 336dfb52..7279cc00 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -46,7 +46,7 @@ if __name__ == '__main__': # FIXME description of the tutorial # Choose problem geometry and masking - mode = 0 + mode = 1 if mode == 0: # One dimensional regular grid position_space = ift.RGSpace([1024]) @@ -108,10 +108,12 @@ if __name__ == '__main__': label=['Mock signal', 'Data', 'Reconstruction'], alpha=[1, .3, 1]) ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL))) - ift.plot_finish(1, 2, xsize=10, ysize=4, title="getting_started_1") + ift.plot_finish(nx=2, ny=1, xsize=10, ysize=4, + title="getting_started_1") else: ift.plot(HT(MOCK_SIGNAL), title='Mock Signal') ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), title='Data') ift.plot(HT(m), title='Reconstruction') ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL))) - ift.plot_finish(2, 2, xsize=10, ysize=8, title="getting_started_1") + ift.plot_finish(nx=4, ny=1, xsize=20, ysize=4, + title="getting_started_1") diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index da25759a..47776c1c 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -83,13 +83,14 @@ if __name__ == '__main__': INITIAL_POSITION = ift.from_random('normal', H.position.domain) position = INITIAL_POSITION - ift.plot(signal.at(MOCK_POSITION).value, name='truth.png') - ift.plot(R.adjoint_times(data), name='data.png') - ift.plot([A.at(MOCK_POSITION).value], name='power.png') + ift.plot(signal.at(MOCK_POSITION).value, title='ground truth') + ift.plot(R.adjoint_times(data), title='data') + ift.plot([A.at(MOCK_POSITION).value], title='power') + ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png") # number of samples used to estimate the KL N_samples = 20 - for i in range(5): + for i in range(2): H = H.at(position) samples = [H.metric.draw_sample(from_inverse=True) for _ in range(N_samples)] @@ -99,17 +100,19 @@ if __name__ == '__main__': KL, convergence = minimizer(KL) position = KL.position - ift.plot(signal.at(position).value, name='reconstruction.png') + ift.plot(signal.at(position).value, title="reconstruction") ift.plot([A.at(position).value, A.at(MOCK_POSITION).value], - name='power.png') + title="power") + ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") sc = ift.StatCalculator() for sample in samples: sc.add(signal.at(sample+position).value) - ift.plot(sc.mean, name='avrg.png') - ift.plot(ift.sqrt(sc.var), name='std.png') + ift.plot(sc.mean, title="mean") + ift.plot(ift.sqrt(sc.var), title="std deviation") powers = [A.at(s+position).value for s in samples] ift.plot([A.at(position).value, A.at(MOCK_POSITION).value]+powers, - name='power.png') + title="power") + ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="results.png") diff --git a/nifty5/plotting/plot.py b/nifty5/plotting/plot.py index 5413d76b..3aeaa642 100644 --- a/nifty5/plotting/plot.py +++ b/nifty5/plotting/plot.py @@ -86,16 +86,6 @@ def _makeplot(name): elif extension == ".png": plt.savefig(name) plt.close() - # elif extension==".html": - # import mpld3 - # mpld3.save_html(plt.gcf(),fileobj=name,no_extras=True) - # import plotly.offline as py - # import plotly.tools as tls - # plotly_fig = tls.mpl_to_plotly(plt.gcf()) - # py.plot(plotly_fig,filename=name) - # py.plot_mpl(plt.gcf(),filename=name) - # import bokeh - # bokeh.mpl.to_bokeh(plt.gcf()) else: raise ValueError("file format not understood") @@ -306,18 +296,20 @@ def plot(f, **kwargs): _plots.append(f) _kwargs.append(kwargs) -def plot_finish(nx, ny, **kwargs): +def plot_finish(**kwargs): global _plots, _kwargs import matplotlib.pyplot as plt nplot = len(_plots) fig = plt.figure() if "title" in kwargs: plt.suptitle(kwargs.pop("title")) + nx = kwargs.pop("nx", 1) + ny = kwargs.pop("ny", 1) xsize = kwargs.pop("xsize", 6) ysize = kwargs.pop("ysize", 6) fig.set_size_inches(xsize, ysize) for i in range(nplot): - ax = fig.add_subplot(nx,ny,i+1) + ax = fig.add_subplot(ny,nx,i+1) _plot(_plots[i], ax, **_kwargs[i]) _makeplot(kwargs.pop("name", None)) _plots = [] -- GitLab