From 012f46435c8cce3aecb54600586785b8fc742c6b Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Fri, 10 Aug 2018 12:27:28 +0200 Subject: [PATCH] plotting tweaks --- demos/plot_test.py | 4 +- nifty5/plotting/plot.py | 81 ++++++++++++++--------------------------- 2 files changed, 29 insertions(+), 56 deletions(-) diff --git a/demos/plot_test.py b/demos/plot_test.py index 69ad9421..0138b31c 100644 --- a/demos/plot_test.py +++ b/demos/plot_test.py @@ -4,7 +4,7 @@ import numpy as np def plot_test(): rg_space1 = ift.makeDomain(ift.RGSpace((100,))) - rg_space2 = ift.makeDomain(ift.RGSpace((80, 80))) + rg_space2 = ift.makeDomain(ift.RGSpace((80, 60), distances=1)) hp_space = ift.makeDomain(ift.HPSpace(64)) gl_space = ift.makeDomain(ift.GLSpace(128)) @@ -13,7 +13,7 @@ def plot_test(): field_rg1_1 = ift.Field.from_global_data(rg_space1, np.random.randn(100)) field_rg1_2 = ift.Field.from_global_data(rg_space1, np.random.randn(100)) field_rg2 = ift.Field.from_global_data( - rg_space2, np.random.randn(80 ** 2).reshape((80, 80))) + rg_space2, np.random.randn(80*60).reshape((80, 60))) field_hp = ift.Field.from_global_data(hp_space, np.random.randn(12*64**2)) field_gl = ift.Field.from_global_data(gl_space, np.random.randn(32640)) field_ps = ift.power_analyze(fft.times(field_rg2)) diff --git a/nifty5/plotting/plot.py b/nifty5/plotting/plot.py index 00d543d0..1c930fe1 100644 --- a/nifty5/plotting/plot.py +++ b/nifty5/plotting/plot.py @@ -45,11 +45,9 @@ def _mollweide_helper(xsize): xsize = int(xsize) ysize = xsize//2 res = np.full(shape=(ysize, xsize), fill_value=np.nan, dtype=np.float64) - xc = (xsize-1)*0.5 - yc = (ysize-1)*0.5 + xc, yc = (xsize-1)*0.5, (ysize-1)*0.5 u, v = np.meshgrid(np.arange(xsize), np.arange(ysize)) - u = 2*(u-xc)/(xc/1.02) - v = (v-yc)/(yc/1.02) + u, v = 2*(u-xc)/(xc/1.02), (v-yc)/(yc/1.02) mask = np.where((u*u*0.25 + v*v) <= 1.) t1 = v[mask] @@ -62,11 +60,8 @@ def _mollweide_helper(xsize): def _find_closest(A, target): # A must be sorted - idx = A.searchsorted(target) - idx = np.clip(idx, 1, len(A)-1) - left = A[idx-1] - right = A[idx] - idx -= target - left < right - target + idx = np.clip(A.searchsorted(target), 1, len(A)-1) + idx -= target - A[idx-1] < A[idx] - target return idx @@ -80,10 +75,7 @@ def _makeplot(name): plt.close() return extension = os.path.splitext(name)[1] - if extension == ".pdf": - plt.savefig(name) - plt.close() - elif extension == ".png": + if extension in (".pdf", ".png"): plt.savefig(name) plt.close() else: @@ -186,22 +178,16 @@ def _plot(f, ax, **kwargs): raise ValueError("PowerSpace or 1D RGSpace required") label = kwargs.pop("label", None) - if label is None: - label = [None] * len(f) if not isinstance(label, list): - label = [label] + label = [label] * len(f) - linewidth = kwargs.pop("linewidth", None) - if linewidth is None: - linewidth = [1.] * len(f) + linewidth = kwargs.pop("linewidth", 1.) if not isinstance(linewidth, list): - linewidth = [linewidth] + linewidth = [linewidth] * len(f) alpha = kwargs.pop("alpha", None) - if alpha is None: - alpha = [None] * len(f) if not isinstance(alpha, list): - alpha = [alpha] + alpha = [alpha] * len(f) foo = kwargs.pop("norm", None) norm = {} if foo is None else {'norm': foo} @@ -225,14 +211,12 @@ def _plot(f, ax, **kwargs): plt.legend() return elif len(dom.shape) == 2: - f = f[0] nx, ny = dom.shape dx, dy = dom.distances - im = ax.imshow(fld.to_global_data().T, - extent=[0, nx*dx, 0, ny*dy], - vmin=kwargs.get("zmin"), - vmax=kwargs.get("zmax"), cmap=cmap, origin="lower", - **norm) + im = ax.imshow( + f[0].to_global_data().T, extent=[0, nx*dx, 0, ny*dy], + vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), + cmap=cmap, origin="lower", **norm) # from mpl_toolkits.axes_grid1 import make_axes_locatable # divider = make_axes_locatable(ax) # cax = divider.append_axes("right", size="5%", pad=0.05) @@ -252,34 +236,23 @@ def _plot(f, ax, **kwargs): if label != ([None]*len(f)): plt.legend() return - elif isinstance(dom, HPSpace): - f = f[0] + elif isinstance(dom, (HPSpace, GLSpace)): import pyHealpix xsize = 800 res, mask, theta, phi = _mollweide_helper(xsize) - - ptg = np.empty((phi.size, 2), dtype=np.float64) - ptg[:, 0] = theta - ptg[:, 1] = phi - base = pyHealpix.Healpix_Base(int(np.sqrt(f.size//12)), "RING") - res[mask] = f.to_global_data()[base.ang2pix(ptg)] - plt.axis('off') - plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), - cmap=cmap, origin="lower") - plt.colorbar(orientation="horizontal") - return - elif isinstance(dom, GLSpace): - f = f[0] - import pyHealpix - xsize = 800 - res, mask, theta, phi = _mollweide_helper(xsize) - ra = np.linspace(0, 2*np.pi, dom.nlon+1) - dec = pyHealpix.GL_thetas(dom.nlat) - ilat = _find_closest(dec, theta) - ilon = _find_closest(ra, phi) - ilon = np.where(ilon == dom.nlon, 0, ilon) - res[mask] = f.to_global_data()[ilat*dom.nlon + ilon] - + if isinstance(dom, HPSpace): + ptg = np.empty((phi.size, 2), dtype=np.float64) + ptg[:, 0] = theta + ptg[:, 1] = phi + base = pyHealpix.Healpix_Base(int(np.sqrt(f[0].size//12)), "RING") + res[mask] = f[0].to_global_data()[base.ang2pix(ptg)] + else: + ra = np.linspace(0, 2*np.pi, dom.nlon+1) + dec = pyHealpix.GL_thetas(dom.nlat) + ilat = _find_closest(dec, theta) + ilon = _find_closest(ra, phi) + ilon = np.where(ilon == dom.nlon, 0, ilon) + res[mask] = f[0].to_global_data()[ilat*dom.nlon + ilon] plt.axis('off') plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") -- GitLab