From 7b46bd6b383c1506c99aaba0201df48885b9a7ea Mon Sep 17 00:00:00 2001 From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de> Date: Thu, 5 Jul 2018 19:44:40 +0200 Subject: [PATCH] added plotting for mesh --- tfields/__init__.py | 1 + tfields/mesh3D.py | 31 +++- tfields/plotting/__init__.py | 275 ++++++++++++++++++++++++++++++ tfields/plotting/mpl.py | 313 +++++++++++++++++++++++++++++++++++ 4 files changed, 615 insertions(+), 5 deletions(-) create mode 100644 tfields/plotting/__init__.py create mode 100644 tfields/plotting/mpl.py diff --git a/tfields/__init__.py b/tfields/__init__.py index 4397180..2a61ba6 100644 --- a/tfields/__init__.py +++ b/tfields/__init__.py @@ -2,6 +2,7 @@ from . import core from . import bases from . import lib from .lib import * +from . import plotting # __all__ = ['core', 'points3D'] from .core import Tensors, TensorFields, TensorMaps diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py index c558632..ab406dc 100644 --- a/tfields/mesh3D.py +++ b/tfields/mesh3D.py @@ -723,11 +723,32 @@ class Mesh3D(tfields.TensorMaps): return obj, template return obj - def plot(self): - import mplTools as mpt - mpt.plotMesh(self, self.faces, color=self.maps[0].fields[0], vmin=0, - vmax=20, axis=mpt.gca(3)) - mpt.plt.show() + def plot(self, **kwargs): + """ + Forwarding to plotTools.plotMesh + """ + scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']]) + map_index = kwargs.pop('map_index', None if not scalars_demanded else 0) + if map_index is not None: + if not len(self.maps[0]) == 0: + kwargs['color'] = self.maps[0].fields[map_index] + + dim_defined = False + if 'axis' in kwargs: + dim_defined = True + if 'zAxis' in kwargs: + if kwargs['zAxis'] is not None: + kwargs['dim'] = 3 + else: + kwargs['dim'] = 2 + dim_defined = True + if 'dim' in kwargs: + dim_defined = True + + if not dim_defined: + kwargs['dim'] = 2 + + return tfields.plotting.plotMesh(self, self.faces, **kwargs) if __name__ == '__main__': diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py new file mode 100644 index 0000000..ac69c06 --- /dev/null +++ b/tfields/plotting/__init__.py @@ -0,0 +1,275 @@ +""" +Core plotting tools for tfields library. Especially PlotOptions class +is basis for many plotting expansions +""" +import warnings +import matplotlib.pyplot as plt +import matplotlib as mpl +import numpy as np +from .mpl import * + + +def setDefault(dictionary, attr, value): + """ + Set defaults to a dictionary + """ + if attr not in dictionary: + dictionary[attr] = value + + +def gca(dim=None, **kwargs): + """ + Forwarding to plt.gca but translating the dimension to projection + correct dimension + """ + if dim == 3: + axis = plt.gca(projection='3d', **kwargs) + else: + axis = plt.gca(**kwargs) + if dim != axisDim(axis): + if dim is not None: + warnings.warn("You have another dimension set as gca." + "I will force the new dimension to return.") + axis = plt.gcf().add_subplot(1, 1, 1, **kwargs) + return axis + + +def axisDim(axis): + """ + Returns int: axis dimension + """ + if hasattr(axis, 'get_zlim'): + return 3 + else: + return 2 + + +def setLabels(axis, *labels): + axis.set_xlabel(labels[0]) + axis.set_ylabel(labels[1]) + if axisDim(axis) == 3: + axis.set_zlabel(labels[2]) + + +def autoscale3D(axis, array=None, xLim=None, yLim=None, zLim=None): + if array is not None: + xMin, yMin, zMin = array.min(axis=0) + xMax, yMax, zMax = array.max(axis=0) + xLim = (xMin, xMax) + yLim = (yMin, yMax) + zLim = (zMin, zMax) + xLimAxis = axis.get_xlim() + yLimAxis = axis.get_ylim() + zLimAxis = axis.get_zlim() + + if not False: + # not empty axis + xMin = min(xLimAxis[0], xLim[0]) + yMin = min(yLimAxis[0], yLim[0]) + zMin = min(zLimAxis[0], zLim[0]) + xMax = max(xLimAxis[1], xLim[1]) + yMax = max(yLimAxis[1], yLim[1]) + zMax = max(zLimAxis[1], zLim[1]) + axis.set_xlim([xMin, xMax]) + axis.set_ylim([yMin, yMax]) + axis.set_zlim([zMin, zMax]) + + +class PlotOptions(object): + """ + processing kwargs for plotting functions and providing easy + access to axis, dimension and plotting method as well as indices + for array choice (x..., y..., zAxis) + """ + def __init__(self, kwargs): + kwargs = dict(kwargs) + self.axis = kwargs.pop('axis', None) + self.dim = kwargs.pop('dim', None) + self.method = kwargs.pop('methodName', None) + self.setXYZAxis(kwargs) + self.plotKwargs = kwargs + + @property + def method(self): + """ + Method for plotting. Will be callable together with plotKwargs + """ + return self._method + + @method.setter + def method(self, methodName): + if not isinstance(methodName, str): + self._method = methodName + else: + self._method = getattr(self.axis, methodName) + + @property + def dim(self): + """ + axis dimension + """ + return self._dim + + @dim.setter + def dim(self, dim): + if dim is None: + if self._axis is None: + dim = 2 + dim = axisDim(self._axis) + elif self._axis is not None: + if not dim == axisDim(self._axis): + raise ValueError("Axis and dim argument are in conflict.") + if dim not in [2, 3]: + raise NotImplementedError("Dimensions other than 2 or 3 are not supported.") + self._dim = dim + + @property + def axis(self): + """ + The plt.Axis object that belongs to this instance + """ + if self._axis is None: + return gca(self._dim) + else: + return self._axis + + @axis.setter + def axis(self, axis): + self._axis = axis + + def setXYZAxis(self, kwargs): + self._xAxis = kwargs.pop('xAxis', 0) + self._yAxis = kwargs.pop('yAxis', 1) + zAxis = kwargs.pop('zAxis', None) + if zAxis is None and self.dim == 3: + indicesUsed = [0, 1, 2] + indicesUsed.remove(self._xAxis) + indicesUsed.remove(self._yAxis) + zAxis = indicesUsed[0] + self._zAxis = zAxis + + def getXYZAxis(self): + return self._xAxis, self._yAxis, self._zAxis + + def setVminVmaxAuto(self, vmin, vmax, scalars): + """ + Automatically set vmin and vmax as min/max of scalars + but only if vmin or vmax is None + """ + if scalars is None: + return + if len(scalars) < 2: + warnings.warn("Need at least two scalars to autoset vmin and/or vmax!") + return + if vmin is None: + vmin = min(scalars) + self.plotKwargs['vmin'] = vmin + if vmax is None: + vmax = min(scalars) + self.plotKwargs['vmax'] = vmax + + def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None): + if cmapDefault is None: + cmapDefault = plt.rcParams['image.cmap'] + cmap = self.get('cmap', cmapDefault) + vmin = self.get('vmin', vminDefault) + vmax = self.get('vmax', vmaxDefault) + return cmap, vmin, vmax + + def formatColors(self, colors, fmt='rgba', length=None): + """ + format colors according to fmt argument + Args: + colors (list/one value of rgba tuples/int/float/str): This argument will + be interpreted as color + fmt (str): rgba / norm + length (int/None): if not None: corrct colors lenght + + Returns: + colors in fmt + """ + hasIter = True + if not hasattr(colors, '__iter__'): + # colors is just one element + hasIter = False + colors = [colors] + if hasattr(colors[0], '__iter__') and fmt == 'norm': + # rgba given but norm wanted + cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', + vminDefault=None, + vmaxDefault=None) + colors = getColorsInverse(colors, cmap, vmin, vmax) + self.plotKwargs['vmin'] = vmin + self.plotKwargs['vmax'] = vmax + self.plotKwargs['cmap'] = cmap + elif fmt == 'rgba': + if isinstance(colors[0], str) or isinstance(colors[0], unicode): + # string color defined + colors = map(mpl.colors.to_rgba, colors) + else: + # norm given rgba wanted + cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', + vminDefault=None, + vmaxDefault=None) + self.setVminVmaxAuto(vmin, vmax, colors) + # update vmin and vmax + cmap, vmin, vmax = self.getNormArgs() + colors = getColors(colors, + vmin=vmin, + vmax=vmax, + cmap=cmap) + + if length is not None: + # just one colors value given + if len(colors) != length: + if not len(colors) == 1: + raise ValueError("Can not correct color length") + colors = list(colors) + colors *= length + elif not hasIter: + colors = colors[0] + + colors = np.array(colors) + return colors + + def delNormArgs(self): + self.plotKwargs.pop('vmin', None) + self.plotKwargs.pop('vmax', None) + self.plotKwargs.pop('cmap', None) + + def getSortedLabels(self, labels): + """ + Returns the labels corresponding to the axes + """ + return [labels[i] for i in self.getXYZAxis() if i is not None] + + def get(self, attr, default=None): + return self.plotKwargs.get(attr, default) + + def pop(self, attr, default=None): + return self.plotKwargs.pop(attr, default) + + def set(self, attr, value): + self.plotKwargs[attr] = value + + def setDefault(self, attr, value): + setDefault(self.plotKwargs, attr, value) + + def retrieve(self, attr, default=None, keep=True): + if keep: + return self.get(attr, default) + else: + return self.pop(attr, default) + + def retrieveChain(self, *args, **kwargs): + default = kwargs.pop('default', None) + keep = kwargs.pop('keep', True) + if len(args) > 1: + return self.retrieve(args[0], + self.retrieveChain(*args[1:], + default=default, + keep=keep), + keep=keep) + if len(args) != 1: + raise ValueError("Invalid number of args ({0})".format(len(args))) + return self.retrieve(args[0], default, keep=keep) diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py new file mode 100644 index 0000000..b3e64b1 --- /dev/null +++ b/tfields/plotting/mpl.py @@ -0,0 +1,313 @@ +import tfields +import numpy as np +import warnings +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.patches import Circle +import mpl_toolkits.mplot3d as plt3D + + + + +def plotArray(array, **kwargs): + """ + Points3D plotting method. + + Args: + axis (matplotlib.Axis) object + xAxis (int): coordinate index that should be on xAxis + yAxis (int): coordinate index that should be on yAxis + zAxis (int or None): coordinate index that should be on zAxis. + If it evaluates to None, 2D plot will be done. + methodName (str): method name to use for filling the axis + + Returns: + Artist or list of Artists (imitating the axis.scatter/plot behaviour). + Better Artist not list of Artists + """ + tfields.plotting.setDefault(kwargs, 'methodName', 'scatter') + po = tfields.plotting.PlotOptions(kwargs) + + labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)']) + xAxis, yAxis, zAxis = po.getXYZAxis() + tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) + if zAxis is None: + args = [array[:, xAxis], + array[:, yAxis]] + else: + args = [array[:, xAxis], + array[:, yAxis], + array[:, zAxis]] + artist = po.method(*args, + **po.plotKwargs) + return artist + + +def plotMesh(vertices, faces, **kwargs): + """ + Args: + axis (matplotlib axis) + xAxis (int) + yAxis (int) + zAxis (int) + edgecolor (color) + color (color): if given, use this color for faces in 2D + cmap + vmin + vmax + """ + if faces.shape[0] == 0: + warnings.warn("No faces to plot") + return None + if max(faces.flat) > vertices.shape[0]: + raise ValueError("Some faces point to non existing vertices.") + po = tfields.plotting.PlotOptions(kwargs) + if po.dim == 2: + full = True + import npTools as npt + import pyTools + mesh = npt.Mesh3D(vertices, faces=faces) + xAxis, yAxis, zAxis = po.getXYZAxis() + facecolors = po.retrieveChain('facecolors', 'color', + default=0, + keep=False) + if full: + # implementation that will sort the triangles by zAxis + centroids = mesh.getCentroids() + axesIndices = [0, 1, 2] + axesIndices.pop(axesIndices.index(xAxis)) + axesIndices.pop(axesIndices.index(yAxis)) + zAxis = axesIndices[0] + zs = centroids[:, zAxis] + zs, faces, facecolors = pyTools.array.getSortedBoth(zs, faces, + facecolors) + nFacesInitial = len(faces) + else: + # cut away "back sides" implementation + directionVector = np.array([1., 1., 1.]) + directionVector[xAxis] = 0. + directionVector[yAxis] = 0. + normVectors = mesh.triangles.getNormVectors() + dotProduct = np.dot(normVectors, directionVector) + nFacesInitial = len(faces) + faces = faces[dotProduct > 0] + + vertices = mesh + + po.plotKwargs['methodName'] = 'tripcolor' + po.plotKwargs['triangles'] = faces + + """ + sort out color arguments + """ + facecolors = po.formatColors(facecolors, + fmt='norm', + length=nFacesInitial) + if not full: + facecolors = facecolors[dotProduct > 0] + po.plotKwargs['facecolors'] = facecolors + + d = po.plotKwargs + d['xAxis'] = xAxis + d['yAxis'] = yAxis + artist = plotArray(vertices, **d) + elif po.dim == 3: + label = po.pop('label', None) + color = po.retrieveChain('color', 'c', 'facecolors', + default='grey', + keep=False) + color = po.formatColors(color, + fmt='rgba', + length=len(faces)) + nanMask = np.isnan(color) + if nanMask.any(): + warnings.warn("nan found in colors. Removing the corresponding faces!") + color = color[~nanMask] + faces = faces[~nanMask] + + edgecolor = po.pop('edgecolor', None) + alpha = po.pop('alpha', None) + po.delNormArgs() + + triangles = np.array([vertices[face] for face in faces]) + artist = plt3D.art3d.Poly3DCollection(triangles, **po.plotKwargs) + po.axis.add_collection3d(artist) + + if edgecolor is not None: + artist.set_edgecolor(edgecolor) + artist.set_facecolors(color) + else: + artist.set_color(color) + + if alpha is not None: + artist.set_alpha(alpha) + + # for some reason auto-scale does not work + tfields.plotting.autoscale3D(po.axis, array=vertices) + + # legend lables do not work at all as an argument + if label: + artist.set_label(label) + + # when plotting the legend edgecolors/facecolors2d are needed + artist._edgecolors2d = None + artist._facecolors2d = None + + labelList = ['x (m)', 'y (m)', 'z (m)'] + tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) + + return artist + + +def plotVectorField(points, vectors, **kwargs): + """ + Args: + points (array_like): base vectors + vectors (array_like): direction vectors + """ + po = tfields.plotting.PlotOptions(kwargs) + if points is None: + points = np.full(vectors.shape, 0.) + artists = [] + xAxis, yAxis, zAxis = po.getXYZAxis() + for point, vector in zip(points, vectors): + if po.dim == 3: + artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis], + vector[xAxis], vector[yAxis], vector[zAxis], + **po.plotKwargs)) + else: + artists.append(po.axis.quiver(point[xAxis], point[yAxis], + vector[xAxis], vector[yAxis], + **po.plotKwargs)) + return artists + + +def plotPlane(point, normal, **kwargs): + + def plot_vector(fig, orig, v, color='blue'): + axis = fig.gca(projection='3d') + orig = np.array(orig) + v = np.array(v) + axis.quiver(orig[0], orig[1], orig[2], v[0], v[1], v[2], color=color) + axis.set_xlim(0, 10) + axis.set_ylim(0, 10) + axis.set_zlim(0, 10) + axis = fig.gca(projection='3d') + return fig + + def rotation_matrix(d): + sin_angle = np.linalg.norm(d) + if sin_angle == 0: + return np.identity(3) + d /= sin_angle + eye = np.eye(3) + ddt = np.outer(d, d) + skew = np.array([[0, d[2], -d[1]], + [-d[2], 0, d[0]], + [d[1], -d[0], 0]], + dtype=np.float64) + + M = ddt + np.sqrt(1 - sin_angle**2) * (eye - ddt) + sin_angle * skew + return M + + def pathpatch_2d_to_3d(pathpatch, z, normal): + if type(normal) is str: # Translate strings to normal vectors + index = "xyz".index(normal) + normal = np.roll((1.0, 0, 0), index) + + normal /= np.linalg.norm(normal) # Make sure the vector is normalised + path = pathpatch.get_path() # Get the path and the associated transform + trans = pathpatch.get_patch_transform() + + path = trans.transform_path(path) # Apply the transform + + pathpatch.__class__ = plt3D.art3d.PathPatch3D # Change the class + pathpatch._code3d = path.codes # Copy the codes + pathpatch._facecolor3d = pathpatch.get_facecolor # Get the face color + + verts = path.vertices # Get the vertices in 2D + + d = np.cross(normal, (0, 0, 1)) # Obtain the rotation vector + M = rotation_matrix(d) # Get the rotation matrix + + pathpatch._segment3d = np.array([np.dot(M, (x, y, 0)) + (0, 0, z) for x, y in verts]) + + def pathpatch_translate(pathpatch, delta): + pathpatch._segment3d += delta + + kwargs['alpha'] = kwargs.pop('alpha', 0.5) + po = tfields.plotting.PlotOptions(kwargs) + patch = Circle((0, 0), **po.plotKwargs) + po.axis.add_patch(patch) + pathpatch_2d_to_3d(patch, z=0, normal=normal) + pathpatch_translate(patch, (point[0], point[1], point[2])) + + +def plotSphere(point, radius, **kwargs): + po = tfields.plotting.PlotOptions(kwargs) + # Make data + u = np.linspace(0, 2 * np.pi, 100) + v = np.linspace(0, np.pi, 100) + x = point[0] + radius * np.outer(np.cos(u), np.sin(v)) + y = point[1] + radius * np.outer(np.sin(u), np.sin(v)) + z = point[2] + radius * np.outer(np.ones(np.size(u)), np.cos(v)) + + # Plot the surface + return po.axis.plot_surface(x, y, z, **po.plotKwargs) + + +""" +Color section +""" +def getColors(scalars, cmap=None, vmin=None, vmax=None): + """ + retrieve the colors for a list of scalars + """ + if not hasattr(scalars, '__iter__'): + scalars = [scalars] + if vmin is None: + vmin = min(scalars) + if vmax is None: + vmax = max(scalars) + colorMap = plt.get_cmap(cmap) + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + return colorMap(map(norm, scalars)) + + +def getColorsInverse(colors, cmap, vmin, vmax): + """ + Reconstruct the numeric values (0 - 1) of given + Args: + colors (list or rgba tuple) + cmap (matplotlib colormap) + vmin (float) + vmax (float) + """ + # colors = np.array(colors)/255. + r = np.linspace(vmin, vmax, 256) + norm = matplotlib.colors.Normalize(vmin, vmax) + mapvals = cmap(norm(r))[:, :4] # there are 4 channels: r,g,b,a + scalars = [] + for color in colors: + distance = np.sum((mapvals - color) ** 2, axis=1) + scalars.append(r[np.argmin(distance)]) + return scalars + + +if __name__ == '__main__': + import tfields + m = tfields.Mesh3D.grid((0, 2, 2), (0, 1, 3), (0, 0, 1)) + m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0]))) + art1 = m.plot(dim=3, map_index=0, label='twenty') + + m = tfields.Mesh3D.grid((4, 7, 2), (3, 5, 3), (2, 2, 1)) + m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0]))) + art = m.plot(dim=3, map_index=0, edgecolor='k', vmin=-1, vmax=1, label="something") + + plotSphere([7, 0, 1], 3) + + # mpt.setLegend(mpt.gca(3), [art1, art]) + # mpt.setAspectEqual(mpt.gca()) + # mpt.setView(vector=[0, 0, 1]) + # mpt.save('/tmp/test', 'png') + # mpt.plt.show() -- GitLab