From df6b5347a83b97dc7fc6e2534db22e87534b5801 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 8 Sep 2017 21:05:57 +0200
Subject: [PATCH] more functionality

---
 nifty2go/plotting/plot.py | 63 +++++++++++++++++++++++++++++++--------
 1 file changed, 50 insertions(+), 13 deletions(-)

diff --git a/nifty2go/plotting/plot.py b/nifty2go/plotting/plot.py
index 89befb7a3..f26e48613 100644
--- a/nifty2go/plotting/plot.py
+++ b/nifty2go/plotting/plot.py
@@ -65,7 +65,20 @@ def _makeplot(name):
     else:
         raise ValueError("file format not understood")
 
-def plot (f,name=None):
+def _limit_xy(**kwargs):
+    import matplotlib.pyplot as plt
+    x1,x2,y1,y2 = plt.axis()
+    if (kwargs.get("xmin")) is not None:
+        x1 = kwargs.get("xmin")
+    if (kwargs.get("xmax")) is not None:
+        x2 = kwargs.get("xmax")
+    if (kwargs.get("ymin")) is not None:
+        y1 = kwargs.get("ymin")
+    if (kwargs.get("ymax")) is not None:
+        y2 = kwargs.get("ymax")
+    plt.axis((x1,x2,y1,y2))
+
+def plot(f,**kwargs):
     import matplotlib.pyplot as plt
     if not isinstance(f,Field):
         raise TypeError("incorrect data type")
@@ -73,8 +86,25 @@ def plot (f,name=None):
         raise ValueError("input field must have exactly one domain")
 
     dom = f.domain[0]
-    plt.gcf().set_size_inches(12,12)
+    fig = plt.figure()
+    ax = fig.add_subplot(1,1,1)
+
+    xsize,ysize = 6,6
+    if kwargs.get("xsize") is not None:
+        xsize = kwargs.get("xsize")
+    if kwargs.get("ysize") is not None:
+        ysize = kwargs.get("ysize")
+    fig.set_size_inches(xsize,ysize)
 
+    if kwargs.get("title") is not None:
+        ax.set_title(kwargs.get("title"))
+    if kwargs.get("xlabel") is not None:
+        ax.set_xlabel(kwargs.get("xlabel"))
+    if kwargs.get("ylabel") is not None:
+        ax.set_ylabel(kwargs.get("ylabel"))
+    cmap=plt.rcParams['image.cmap']
+    if kwargs.get("colormap") is not None:
+        cmap = kwargs.get("colormap")
     if isinstance(dom, RGSpace):
         if len(dom.shape)==1:
             npoints = dom.shape[0]
@@ -82,7 +112,8 @@ def plot (f,name=None):
             xcoord = np.arange(npoints,dtype=np.float64)*dist
             ycoord = f.val
             plt.plot(xcoord, ycoord)
-            _makeplot(name)
+            _limit_xy(**kwargs)
+            _makeplot(kwargs.get("name"))
             return
         elif len(dom.shape)==2:
             nx = dom.shape[0]
@@ -91,9 +122,14 @@ def plot (f,name=None):
             dy = dom.distances[1]
             xc = np.arange(nx,dtype=np.float64)*dx
             yc = np.arange(ny,dtype=np.float64)*dy
-            plt.imshow(f.val,extent=[xc[0],xc[-1],yc[0],yc[-1]])
-            plt.colorbar()
-            _makeplot(name)
+            im=ax.imshow(f.val,extent=[xc[0],xc[-1],yc[0],yc[-1]],vmin=kwargs.get("zmin"),vmax=kwargs.get("zmax"),cmap=cmap)
+            #from mpl_toolkits.axes_grid1 import make_axes_locatable
+            #divider = make_axes_locatable(ax)
+            #cax = divider.append_axes("right", size="5%", pad=0.05)
+            #plt.colorbar(im,cax=cax)
+            plt.colorbar(im)
+            _limit_xy(**kwargs)
+            _makeplot(kwargs.get("name"))
             return
     elif isinstance(dom, PowerSpace):
         xcoord = dom.kindex
@@ -102,7 +138,8 @@ def plot (f,name=None):
         plt.yscale('log')
         plt.title('power')
         plt.plot(xcoord, ycoord)
-        _makeplot(name)
+        _limit_xy(**kwargs)
+        _makeplot(kwargs.get("name"))
         return
     elif isinstance(dom, HPSpace):
         import pyHealpix
@@ -115,9 +152,9 @@ def plot (f,name=None):
         base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING")
         res[mask] = f.val[base.ang2pix(ptg)]
         plt.axis('off')
-        plt.imshow(res)
-        plt.colorbar()
-        _makeplot(name)
+        plt.imshow(res,vmin=kwargs.get("zmin"),vmax=kwargs.get("zmax"),cmap=cmap)
+        plt.colorbar(orientation="horizontal")
+        _makeplot(kwargs.get("name"))
         return
     elif isinstance(dom, GLSpace):
         import pyHealpix
@@ -131,9 +168,9 @@ def plot (f,name=None):
         res[mask] = f.val[ilat*dom.nlon + ilon]
 
         plt.axis('off')
-        plt.imshow(res)
-        plt.colorbar()
-        _makeplot(name)
+        plt.imshow(res,vmin=kwargs.get("zmin"),vmax=kwargs.get("zmax"),cmap=cmap)
+        plt.colorbar(orientation="horizontal")
+        _makeplot(kwargs.get("name"))
         return
 
     raise ValueError("Field type not(yet) supported")
-- 
GitLab