From 7b46bd6b383c1506c99aaba0201df48885b9a7ea Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Thu, 5 Jul 2018 19:44:40 +0200
Subject: [PATCH] added plotting for mesh

---
 tfields/__init__.py          |   1 +
 tfields/mesh3D.py            |  31 +++-
 tfields/plotting/__init__.py | 275 ++++++++++++++++++++++++++++++
 tfields/plotting/mpl.py      | 313 +++++++++++++++++++++++++++++++++++
 4 files changed, 615 insertions(+), 5 deletions(-)
 create mode 100644 tfields/plotting/__init__.py
 create mode 100644 tfields/plotting/mpl.py

diff --git a/tfields/__init__.py b/tfields/__init__.py
index 4397180..2a61ba6 100644
--- a/tfields/__init__.py
+++ b/tfields/__init__.py
@@ -2,6 +2,7 @@ from . import core
 from . import bases
 from . import lib
 from .lib import *
+from . import plotting
 
 # __all__ = ['core', 'points3D']
 from .core import Tensors, TensorFields, TensorMaps
diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py
index c558632..ab406dc 100644
--- a/tfields/mesh3D.py
+++ b/tfields/mesh3D.py
@@ -723,11 +723,32 @@ class Mesh3D(tfields.TensorMaps):
             return obj, template
         return obj
 
-    def plot(self):
-        import mplTools as mpt
-        mpt.plotMesh(self, self.faces, color=self.maps[0].fields[0], vmin=0,
-                     vmax=20, axis=mpt.gca(3))
-        mpt.plt.show()
+    def plot(self, **kwargs):
+        """
+        Forwarding to plotTools.plotMesh
+        """
+        scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']])
+        map_index = kwargs.pop('map_index', None if not scalars_demanded else 0)
+        if map_index is not None:
+            if not len(self.maps[0]) == 0:
+                kwargs['color'] = self.maps[0].fields[map_index]
+
+        dim_defined = False
+        if 'axis' in kwargs:
+            dim_defined = True
+        if 'zAxis' in kwargs:
+            if kwargs['zAxis'] is not None:
+                kwargs['dim'] = 3
+            else:
+                kwargs['dim'] = 2
+            dim_defined = True
+        if 'dim' in kwargs:
+            dim_defined = True
+
+        if not dim_defined:
+            kwargs['dim'] = 2
+
+        return tfields.plotting.plotMesh(self, self.faces, **kwargs)
 
 
 if __name__ == '__main__':
diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py
new file mode 100644
index 0000000..ac69c06
--- /dev/null
+++ b/tfields/plotting/__init__.py
@@ -0,0 +1,275 @@
+"""
+Core plotting tools for tfields library. Especially PlotOptions class
+is basis for many plotting expansions
+"""
+import warnings
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from .mpl import *
+
+
+def setDefault(dictionary, attr, value):
+    """
+    Set defaults to a dictionary
+    """
+    if attr not in dictionary:
+        dictionary[attr] = value
+
+
+def gca(dim=None, **kwargs):
+    """
+    Forwarding to plt.gca but translating the dimension to projection
+    correct dimension
+    """
+    if dim == 3:
+        axis = plt.gca(projection='3d', **kwargs)
+    else:
+        axis = plt.gca(**kwargs)
+        if dim != axisDim(axis):
+            if dim is not None:
+                warnings.warn("You have another dimension set as gca."
+                              "I will force the new dimension to return.")
+                axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
+    return axis
+
+
+def axisDim(axis):
+    """
+    Returns int: axis dimension
+    """
+    if hasattr(axis, 'get_zlim'):
+        return 3
+    else:
+        return 2
+
+
+def setLabels(axis, *labels):
+    axis.set_xlabel(labels[0])
+    axis.set_ylabel(labels[1])
+    if axisDim(axis) == 3:
+        axis.set_zlabel(labels[2])
+
+
+def autoscale3D(axis, array=None, xLim=None, yLim=None, zLim=None):
+    if array is not None:
+        xMin, yMin, zMin = array.min(axis=0)
+        xMax, yMax, zMax = array.max(axis=0)
+        xLim = (xMin, xMax)
+        yLim = (yMin, yMax)
+        zLim = (zMin, zMax)
+    xLimAxis = axis.get_xlim()
+    yLimAxis = axis.get_ylim()
+    zLimAxis = axis.get_zlim()
+
+    if not False:
+        # not empty axis
+        xMin = min(xLimAxis[0], xLim[0])
+        yMin = min(yLimAxis[0], yLim[0])
+        zMin = min(zLimAxis[0], zLim[0])
+        xMax = max(xLimAxis[1], xLim[1])
+        yMax = max(yLimAxis[1], yLim[1])
+        zMax = max(zLimAxis[1], zLim[1])
+    axis.set_xlim([xMin, xMax])
+    axis.set_ylim([yMin, yMax])
+    axis.set_zlim([zMin, zMax])
+
+
+class PlotOptions(object):
+    """
+    processing kwargs for plotting functions and providing easy
+    access to axis, dimension and plotting method as well as indices
+    for array choice (x..., y..., zAxis)
+    """
+    def __init__(self, kwargs):
+        kwargs = dict(kwargs)
+        self.axis = kwargs.pop('axis', None)
+        self.dim = kwargs.pop('dim', None)
+        self.method = kwargs.pop('methodName', None)
+        self.setXYZAxis(kwargs)
+        self.plotKwargs = kwargs
+
+    @property
+    def method(self):
+        """
+        Method for plotting. Will be callable together with plotKwargs
+        """
+        return self._method
+
+    @method.setter
+    def method(self, methodName):
+        if not isinstance(methodName, str):
+            self._method = methodName
+        else:
+            self._method = getattr(self.axis, methodName)
+
+    @property
+    def dim(self):
+        """
+        axis dimension
+        """
+        return self._dim
+
+    @dim.setter
+    def dim(self, dim):
+        if dim is None:
+            if self._axis is None:
+                dim = 2
+            dim = axisDim(self._axis)
+        elif self._axis is not None:
+            if not dim == axisDim(self._axis):
+                raise ValueError("Axis and dim argument are in conflict.")
+        if dim not in [2, 3]:
+            raise NotImplementedError("Dimensions other than 2 or 3 are not supported.")
+        self._dim = dim
+
+    @property
+    def axis(self):
+        """
+        The plt.Axis object that belongs to this instance
+        """
+        if self._axis is None:
+            return gca(self._dim)
+        else:
+            return self._axis
+
+    @axis.setter
+    def axis(self, axis):
+        self._axis = axis
+
+    def setXYZAxis(self, kwargs):
+        self._xAxis = kwargs.pop('xAxis', 0)
+        self._yAxis = kwargs.pop('yAxis', 1)
+        zAxis = kwargs.pop('zAxis', None)
+        if zAxis is None and self.dim == 3:
+            indicesUsed = [0, 1, 2]
+            indicesUsed.remove(self._xAxis)
+            indicesUsed.remove(self._yAxis)
+            zAxis = indicesUsed[0]
+        self._zAxis = zAxis
+
+    def getXYZAxis(self):
+        return self._xAxis, self._yAxis, self._zAxis
+
+    def setVminVmaxAuto(self, vmin, vmax, scalars):
+        """
+        Automatically set vmin and vmax as min/max of scalars
+        but only if vmin or vmax is None
+        """
+        if scalars is None:
+            return
+        if len(scalars) < 2:
+            warnings.warn("Need at least two scalars to autoset vmin and/or vmax!")
+            return
+        if vmin is None:
+            vmin = min(scalars)
+            self.plotKwargs['vmin'] = vmin
+        if vmax is None:
+            vmax = min(scalars)
+            self.plotKwargs['vmax'] = vmax
+
+    def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None):
+        if cmapDefault is None:
+            cmapDefault = plt.rcParams['image.cmap']
+        cmap = self.get('cmap', cmapDefault)
+        vmin = self.get('vmin', vminDefault)
+        vmax = self.get('vmax', vmaxDefault)
+        return cmap, vmin, vmax
+
+    def formatColors(self, colors, fmt='rgba', length=None):
+        """
+        format colors according to fmt argument
+        Args:
+            colors (list/one value of rgba tuples/int/float/str): This argument will
+                be interpreted as color
+            fmt (str): rgba / norm
+            length (int/None): if not None: corrct colors lenght
+    
+        Returns:
+            colors in fmt
+        """
+        hasIter = True
+        if not hasattr(colors, '__iter__'):
+            # colors is just one element
+            hasIter = False
+            colors = [colors]
+        if hasattr(colors[0], '__iter__') and fmt == 'norm':
+            # rgba given but norm wanted
+            cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
+                                                vminDefault=None,
+                                                vmaxDefault=None)
+            colors = getColorsInverse(colors, cmap, vmin, vmax)
+            self.plotKwargs['vmin'] = vmin
+            self.plotKwargs['vmax'] = vmax
+            self.plotKwargs['cmap'] = cmap
+        elif fmt == 'rgba':
+            if isinstance(colors[0], str) or isinstance(colors[0], unicode):
+                # string color defined
+                colors = map(mpl.colors.to_rgba, colors)
+            else:
+                # norm given rgba wanted
+                cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
+                                                    vminDefault=None,
+                                                    vmaxDefault=None)
+                self.setVminVmaxAuto(vmin, vmax, colors)
+                # update vmin and vmax
+                cmap, vmin, vmax = self.getNormArgs()
+                colors = getColors(colors,
+                                   vmin=vmin,
+                                   vmax=vmax,
+                                   cmap=cmap)
+    
+        if length is not None:
+            # just one colors value given
+            if len(colors) != length:
+                if not len(colors) == 1:
+                    raise ValueError("Can not correct color length")
+                colors = list(colors)
+                colors *= length
+        elif not hasIter:
+            colors = colors[0]
+    
+        colors = np.array(colors)
+        return colors
+
+    def delNormArgs(self):
+        self.plotKwargs.pop('vmin', None)
+        self.plotKwargs.pop('vmax', None)
+        self.plotKwargs.pop('cmap', None)
+
+    def getSortedLabels(self, labels):
+        """
+        Returns the labels corresponding to the axes
+        """
+        return [labels[i] for i in self.getXYZAxis() if i is not None]
+        
+    def get(self, attr, default=None):
+        return self.plotKwargs.get(attr, default)
+
+    def pop(self, attr, default=None):
+        return self.plotKwargs.pop(attr, default)
+
+    def set(self, attr, value):
+        self.plotKwargs[attr] = value
+
+    def setDefault(self, attr, value):
+        setDefault(self.plotKwargs, attr, value)
+
+    def retrieve(self, attr, default=None, keep=True):
+        if keep:
+            return self.get(attr, default)
+        else:
+            return self.pop(attr, default)
+
+    def retrieveChain(self, *args, **kwargs):
+        default = kwargs.pop('default', None)
+        keep = kwargs.pop('keep', True)
+        if len(args) > 1:
+            return self.retrieve(args[0],
+                                 self.retrieveChain(*args[1:],
+                                                    default=default,
+                                                    keep=keep),
+                                 keep=keep)
+        if len(args) != 1:
+            raise ValueError("Invalid number of args ({0})".format(len(args)))
+        return self.retrieve(args[0], default, keep=keep)
diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py
new file mode 100644
index 0000000..b3e64b1
--- /dev/null
+++ b/tfields/plotting/mpl.py
@@ -0,0 +1,313 @@
+import tfields
+import numpy as np
+import warnings
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from matplotlib.patches import Circle
+import mpl_toolkits.mplot3d as plt3D
+
+
+
+
+def plotArray(array, **kwargs):
+    """
+    Points3D plotting method.
+
+    Args:
+        axis (matplotlib.Axis) object
+        xAxis (int): coordinate index that should be on xAxis
+        yAxis (int): coordinate index that should be on yAxis
+        zAxis (int or None): coordinate index that should be on zAxis.
+            If it evaluates to None, 2D plot will be done.
+        methodName (str): method name to use for filling the axis
+
+    Returns:
+        Artist or list of Artists (imitating the axis.scatter/plot behaviour).
+        Better Artist not list of Artists
+    """
+    tfields.plotting.setDefault(kwargs, 'methodName', 'scatter')
+    po = tfields.plotting.PlotOptions(kwargs)
+
+    labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)'])
+    xAxis, yAxis, zAxis = po.getXYZAxis()
+    tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList))
+    if zAxis is None:
+        args = [array[:, xAxis],
+                array[:, yAxis]]
+    else:
+        args = [array[:, xAxis],
+                array[:, yAxis],
+                array[:, zAxis]]
+    artist = po.method(*args,
+                       **po.plotKwargs)
+    return artist
+
+
+def plotMesh(vertices, faces, **kwargs):
+    """
+    Args:
+        axis (matplotlib axis)
+        xAxis (int)
+        yAxis (int)
+        zAxis (int)
+        edgecolor (color)
+        color (color): if given, use this color for faces in 2D
+        cmap
+        vmin
+        vmax
+    """
+    if faces.shape[0] == 0:
+        warnings.warn("No faces to plot")
+        return None
+    if max(faces.flat) > vertices.shape[0]:
+        raise ValueError("Some faces point to non existing vertices.")
+    po = tfields.plotting.PlotOptions(kwargs)
+    if po.dim == 2:
+        full = True
+        import npTools as npt
+        import pyTools
+        mesh = npt.Mesh3D(vertices, faces=faces)
+        xAxis, yAxis, zAxis = po.getXYZAxis()
+        facecolors = po.retrieveChain('facecolors', 'color',
+                                      default=0,
+                                      keep=False)
+        if full:
+            # implementation that will sort the triangles by zAxis
+            centroids = mesh.getCentroids()
+            axesIndices = [0, 1, 2]
+            axesIndices.pop(axesIndices.index(xAxis))
+            axesIndices.pop(axesIndices.index(yAxis))
+            zAxis = axesIndices[0]
+            zs = centroids[:, zAxis]
+            zs, faces, facecolors = pyTools.array.getSortedBoth(zs, faces,
+                                                                facecolors)
+            nFacesInitial = len(faces)
+        else:
+            # cut away "back sides" implementation
+            directionVector = np.array([1., 1., 1.])
+            directionVector[xAxis] = 0.
+            directionVector[yAxis] = 0.
+            normVectors = mesh.triangles.getNormVectors()
+            dotProduct = np.dot(normVectors, directionVector)
+            nFacesInitial = len(faces)
+            faces = faces[dotProduct > 0]
+
+        vertices = mesh
+
+        po.plotKwargs['methodName'] = 'tripcolor'
+        po.plotKwargs['triangles'] = faces
+
+        """
+        sort out color arguments
+        """
+        facecolors = po.formatColors(facecolors,
+                                     fmt='norm',
+                                     length=nFacesInitial)
+        if not full:
+            facecolors = facecolors[dotProduct > 0]
+        po.plotKwargs['facecolors'] = facecolors
+
+        d = po.plotKwargs
+        d['xAxis'] = xAxis
+        d['yAxis'] = yAxis
+        artist = plotArray(vertices, **d)
+    elif po.dim == 3:
+        label = po.pop('label', None)
+        color = po.retrieveChain('color', 'c', 'facecolors',
+                                 default='grey',
+                                 keep=False)
+        color = po.formatColors(color,
+                                fmt='rgba',
+                                length=len(faces))
+        nanMask = np.isnan(color)
+        if nanMask.any():
+            warnings.warn("nan found in colors. Removing the corresponding faces!")
+            color = color[~nanMask]
+            faces = faces[~nanMask]
+
+        edgecolor = po.pop('edgecolor', None)
+        alpha = po.pop('alpha', None)
+        po.delNormArgs()
+
+        triangles = np.array([vertices[face] for face in faces])
+        artist = plt3D.art3d.Poly3DCollection(triangles, **po.plotKwargs)
+        po.axis.add_collection3d(artist)
+
+        if edgecolor is not None:
+            artist.set_edgecolor(edgecolor)
+            artist.set_facecolors(color)
+        else:
+            artist.set_color(color)
+
+        if alpha is not None:
+            artist.set_alpha(alpha)
+
+        # for some reason auto-scale does not work
+        tfields.plotting.autoscale3D(po.axis, array=vertices)
+
+        # legend lables do not work at all as an argument
+        if label:
+            artist.set_label(label)
+
+        # when plotting the legend edgecolors/facecolors2d are needed
+        artist._edgecolors2d = None
+        artist._facecolors2d = None
+
+        labelList = ['x (m)', 'y (m)', 'z (m)']
+        tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList))
+
+    return artist
+
+
+def plotVectorField(points, vectors, **kwargs):
+    """
+    Args:
+        points (array_like): base vectors
+        vectors (array_like): direction vectors
+    """
+    po = tfields.plotting.PlotOptions(kwargs)
+    if points is None:
+        points = np.full(vectors.shape, 0.)
+    artists = []
+    xAxis, yAxis, zAxis = po.getXYZAxis()
+    for point, vector in zip(points, vectors):
+        if po.dim == 3:
+            artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis],
+                                          vector[xAxis], vector[yAxis], vector[zAxis],
+                                          **po.plotKwargs))
+        else:
+            artists.append(po.axis.quiver(point[xAxis], point[yAxis],
+                                          vector[xAxis], vector[yAxis],
+                                          **po.plotKwargs))
+    return artists
+
+
+def plotPlane(point, normal, **kwargs):
+
+    def plot_vector(fig, orig, v, color='blue'):
+        axis = fig.gca(projection='3d')
+        orig = np.array(orig)
+        v = np.array(v)
+        axis.quiver(orig[0], orig[1], orig[2], v[0], v[1], v[2], color=color)
+        axis.set_xlim(0, 10)
+        axis.set_ylim(0, 10)
+        axis.set_zlim(0, 10)
+        axis = fig.gca(projection='3d')
+        return fig
+
+    def rotation_matrix(d):
+        sin_angle = np.linalg.norm(d)
+        if sin_angle == 0:
+            return np.identity(3)
+        d /= sin_angle
+        eye = np.eye(3)
+        ddt = np.outer(d, d)
+        skew = np.array([[0, d[2], -d[1]],
+                         [-d[2], 0, d[0]],
+                         [d[1], -d[0], 0]],
+                        dtype=np.float64)
+
+        M = ddt + np.sqrt(1 - sin_angle**2) * (eye - ddt) + sin_angle * skew
+        return M
+
+    def pathpatch_2d_to_3d(pathpatch, z, normal):
+        if type(normal) is str:  # Translate strings to normal vectors
+            index = "xyz".index(normal)
+            normal = np.roll((1.0, 0, 0), index)
+
+        normal /= np.linalg.norm(normal)  # Make sure the vector is normalised
+        path = pathpatch.get_path()  # Get the path and the associated transform
+        trans = pathpatch.get_patch_transform()
+
+        path = trans.transform_path(path)  # Apply the transform
+
+        pathpatch.__class__ = plt3D.art3d.PathPatch3D  # Change the class
+        pathpatch._code3d = path.codes  # Copy the codes
+        pathpatch._facecolor3d = pathpatch.get_facecolor  # Get the face color
+
+        verts = path.vertices  # Get the vertices in 2D
+
+        d = np.cross(normal, (0, 0, 1))  # Obtain the rotation vector
+        M = rotation_matrix(d)  # Get the rotation matrix
+
+        pathpatch._segment3d = np.array([np.dot(M, (x, y, 0)) + (0, 0, z) for x, y in verts])
+
+    def pathpatch_translate(pathpatch, delta):
+        pathpatch._segment3d += delta
+
+    kwargs['alpha'] = kwargs.pop('alpha', 0.5)
+    po = tfields.plotting.PlotOptions(kwargs)
+    patch = Circle((0, 0), **po.plotKwargs)
+    po.axis.add_patch(patch)
+    pathpatch_2d_to_3d(patch, z=0, normal=normal)
+    pathpatch_translate(patch, (point[0], point[1], point[2]))
+
+
+def plotSphere(point, radius, **kwargs):
+    po = tfields.plotting.PlotOptions(kwargs)
+    # Make data
+    u = np.linspace(0, 2 * np.pi, 100)
+    v = np.linspace(0, np.pi, 100)
+    x = point[0] + radius * np.outer(np.cos(u), np.sin(v))
+    y = point[1] + radius * np.outer(np.sin(u), np.sin(v))
+    z = point[2] + radius * np.outer(np.ones(np.size(u)), np.cos(v))
+
+    # Plot the surface
+    return po.axis.plot_surface(x, y, z, **po.plotKwargs)
+
+
+"""
+Color section
+"""
+def getColors(scalars, cmap=None, vmin=None, vmax=None):
+    """
+    retrieve the colors for a list of scalars
+    """
+    if not hasattr(scalars, '__iter__'):
+        scalars = [scalars]
+    if vmin is None:
+        vmin = min(scalars)
+    if vmax is None:
+        vmax = max(scalars)
+    colorMap = plt.get_cmap(cmap)
+    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+    return colorMap(map(norm, scalars))
+
+
+def getColorsInverse(colors, cmap, vmin, vmax):
+    """
+    Reconstruct the numeric values (0 - 1) of given
+    Args:
+        colors (list or rgba tuple)
+        cmap (matplotlib colormap)
+        vmin (float)
+        vmax (float)
+    """
+    # colors = np.array(colors)/255.
+    r = np.linspace(vmin, vmax, 256)
+    norm = matplotlib.colors.Normalize(vmin, vmax)
+    mapvals = cmap(norm(r))[:, :4]  # there are 4 channels: r,g,b,a
+    scalars = []
+    for color in colors:
+        distance = np.sum((mapvals - color) ** 2, axis=1)
+        scalars.append(r[np.argmin(distance)])
+    return scalars
+
+
+if __name__ == '__main__':
+    import tfields
+    m = tfields.Mesh3D.grid((0, 2, 2), (0, 1, 3), (0, 0, 1))
+    m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0])))
+    art1 = m.plot(dim=3, map_index=0, label='twenty')
+
+    m = tfields.Mesh3D.grid((4, 7, 2), (3, 5, 3), (2, 2, 1))
+    m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0])))
+    art = m.plot(dim=3, map_index=0, edgecolor='k', vmin=-1, vmax=1, label="something")
+
+    plotSphere([7, 0, 1], 3)
+
+    # mpt.setLegend(mpt.gca(3), [art1, art])
+    # mpt.setAspectEqual(mpt.gca())
+    # mpt.setView(vector=[0, 0, 1])
+    # mpt.save('/tmp/test', 'png')
+    # mpt.plt.show()
-- 
GitLab