From 4af129bb0b961aee8bf0a20f07578af531032386 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Fri, 13 Jul 2018 14:53:39 +0200 Subject: [PATCH] more tweaks --- demos/getting_started_1.py | 20 ++++++++++---------- nifty5/__init__.py | 2 +- nifty5/plotting/plot.py | 34 ++++++++++++++++++---------------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 0a0220c0..336dfb52 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -46,7 +46,7 @@ if __name__ == '__main__': # FIXME description of the tutorial # Choose problem geometry and masking - mode = 1 + mode = 0 if mode == 0: # One dimensional regular grid position_space = ift.RGSpace([1024]) @@ -104,14 +104,14 @@ if __name__ == '__main__': # PLOTTING rg = isinstance(position_space, ift.RGSpace) if rg and len(position_space.shape) == 1: - ift.add_plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)], + ift.plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)], label=['Mock signal', 'Data', 'Reconstruction'], - alpha=[1, .3, 1], - name='getting_started_1.png') + alpha=[1, .3, 1]) + ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL))) + ift.plot_finish(1, 2, xsize=10, ysize=4, title="getting_started_1") else: - ift.add_plot(HT(MOCK_SIGNAL), title='Mock Signal', name='mock_signal.png') - ift.add_plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), - title='Data', name='data.png') - ift.add_plot(HT(m), title='Reconstruction', name='reconstruction.png') - ift.add_plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), name='residuals.png') - ift.plot() + ift.plot(HT(MOCK_SIGNAL), title='Mock Signal') + ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), title='Data') + ift.plot(HT(m), title='Reconstruction') + ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL))) + ift.plot_finish(2, 2, xsize=10, ysize=8, title="getting_started_1") diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 87fd9d5d..e8b8e1d5 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -73,7 +73,7 @@ from .minimization.line_energy import LineEnergy from .minimization.energy_sum import EnergySum from .sugar import * -from .plotting.plot import add_plot, plot +from .plotting.plot import plot, plot_finish from .library.amplitude_model import make_amplitude_model from .library.gaussian_energy import GaussianEnergy diff --git a/nifty5/plotting/plot.py b/nifty5/plotting/plot.py index c5c6c1aa..5413d76b 100644 --- a/nifty5/plotting/plot.py +++ b/nifty5/plotting/plot.py @@ -23,7 +23,12 @@ import os import numpy as np from ..compat import * -from .. import Field, GLSpace, HPSpace, PowerSpace, RGSpace, dobj +from ..field import Field +from ..domains.gl_space import GLSpace +from ..domains.hp_space import HPSpace +from ..domains.power_space import PowerSpace +from ..domains.rg_space import RGSpace +from .. import dobj # relevant properties: # - x/y size @@ -209,12 +214,6 @@ def _plot(f, ax, **kwargs): alpha = [alpha] dom = dom[0] - #fig = plt.figure() - #ax = fig.add_subplot(1, 1, 1) - - #xsize = kwargs.pop("xsize", 6) - #ysize = kwargs.pop("ysize", 6) - #fig.set_size_inches(xsize, ysize) ax.set_title(kwargs.pop("title", "")) ax.set_xlabel(kwargs.pop("xlabel", "")) ax.set_ylabel(kwargs.pop("ylabel", "")) @@ -231,7 +230,6 @@ def _plot(f, ax, **kwargs): _limit_xy(**kwargs) if label != ([None]*len(f)): plt.legend() - #_makeplot(kwargs.get("name")) return elif len(dom.shape) == 2: f = f[0] @@ -251,7 +249,6 @@ def _plot(f, ax, **kwargs): # plt.colorbar(im,cax=cax) plt.colorbar(im) _limit_xy(**kwargs) - #_makeplot(kwargs.get("name")) return elif isinstance(dom, PowerSpace): plt.xscale('log') @@ -265,7 +262,6 @@ def _plot(f, ax, **kwargs): _limit_xy(**kwargs) if label != ([None]*len(f)): plt.legend() - #_makeplot(kwargs.get("name")) return elif isinstance(dom, HPSpace): f = f[0] @@ -282,7 +278,6 @@ def _plot(f, ax, **kwargs): plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") plt.colorbar(orientation="horizontal") - #_makeplot(kwargs.get("name")) return elif isinstance(dom, GLSpace): f = f[0] @@ -300,7 +295,6 @@ def _plot(f, ax, **kwargs): plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") plt.colorbar(orientation="horizontal") - #_makeplot(kwargs.get("name")) return raise ValueError("Field type not(yet) supported") @@ -308,15 +302,23 @@ def _plot(f, ax, **kwargs): _plots = [] _kwargs = [] -def add_plot(f, **kwargs): +def plot(f, **kwargs): _plots.append(f) _kwargs.append(kwargs) -def plot(**kwargs): +def plot_finish(nx, ny, **kwargs): + global _plots, _kwargs import matplotlib.pyplot as plt nplot = len(_plots) fig = plt.figure() + if "title" in kwargs: + plt.suptitle(kwargs.pop("title")) + xsize = kwargs.pop("xsize", 6) + ysize = kwargs.pop("ysize", 6) + fig.set_size_inches(xsize, ysize) for i in range(nplot): - ax = fig.add_subplot(nplot,1,i+1) + ax = fig.add_subplot(nx,ny,i+1) _plot(_plots[i], ax, **_kwargs[i]) - _makeplot(None) + _makeplot(kwargs.pop("name", None)) + _plots = [] + _kwargs = [] -- GitLab