diff --git a/tfields/core.py b/tfields/core.py index 79424d1e3f4b306767b898cd36a71cea8bb43fae..f9ffe54091c8d61c89654708c43d051e0fc9b361 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -488,7 +488,7 @@ class Tensors(AbstractNdarray): ''' transform all raw inputs to cls type with correct coordSys. Also automatically make a copy of those instances that are of the correct type already.''' - objects = [cls(t, **kwargs) for t in objects] + objects = [cls.__new__(cls, t, **kwargs) for t in objects] ''' check rank and dimension equality ''' if not len(set([t.rank for t in objects])) == 1: @@ -687,7 +687,8 @@ class Tensors(AbstractNdarray): >>> p.mirror(1) >>> assert p.equal([[1, -2, 3], [4, -5, 6], [1, -2, -6]]) - multiple coordinates can be mirrored. Eg. a point mirrorion would be + multiple coordinates can be mirrored at the same time + i.e. a point mirrorion would be >>> p = tfields.Tensors([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]]) >>> p.mirror([0,2]) >>> assert p.equal([[-1, 2, -3], [-4, 5, -6], [-1, 2., 6.]]) @@ -696,7 +697,7 @@ class Tensors(AbstractNdarray): The mirroring will only be applied to the points meeting the condition. >>> import sympy >>> x, y, z = sympy.symbols('x y z') - >>> p.mirror([0,2], y > 3) + >>> p.mirror([0, 2], y > 3) >>> p.equal([[-1, 2, -3], [4, 5, 6], [-1, 2, 6]]) True @@ -707,20 +708,28 @@ class Tensors(AbstractNdarray): condition = self.evalf(condition) if isinstance(coordinate, list) or isinstance(coordinate, tuple): for c in coordinate: - self.mirror(c, condition) + self.mirror(c, condition=condition) elif isinstance(coordinate, int): self[:, coordinate][condition] *= -1 else: raise TypeError() def to_segment(self, segment, num_segments, coordinate, - periodicity=2 * np.pi, offset=0, + periodicity=2 * np.pi, offset=0., coordSys=None): """ For circular (close into themself after <periodicity>) coordinates at index <coordinate> assume <num_segments> segments and transform all values to segment number <segment> + Args: + segment (int): segment index (starting at 0) + num_segments (int): number of segments + coordinate (int): coordinate index + periodicity (float): after what lenght, the coordiante repeats + offset (float): offset in the mapping + coordSys (str or sympy.CoordinateSystem): in which coord sys the + transformation should be done Examples: >>> import tfields >>> import numpy as np @@ -1076,6 +1085,13 @@ class Tensors(AbstractNdarray): evalfs, evecs = np.linalg.eigh(cov) return (evecs * evalfs.T).T + def plot(self, **kwargs): + """ + Forwarding to tfields.lib.plotting.plotArray + """ + artist = tfields.plotting.plot_array(self, **kwargs) + return artist + class TensorFields(Tensors): """ diff --git a/tfields/lib/__init__.py b/tfields/lib/__init__.py index 61f2209313380fd5428967b33d2cf01359d6967a..6755d3ba69a3e6c26f27d2e1d8c375233724f875 100644 --- a/tfields/lib/__init__.py +++ b/tfields/lib/__init__.py @@ -111,3 +111,5 @@ else: from . import symbolics from . import sets from . import util + from . import in_out + from . import log diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py index 98845995a268bf78c7a9bbefe44d3c65e729ee1a..758a7bd96a03b5ec91319fa3bec38ce43feab1bb 100644 --- a/tfields/mesh3D.py +++ b/tfields/mesh3D.py @@ -725,7 +725,7 @@ class Mesh3D(tfields.TensorMaps): def plot(self, **kwargs): # pragma: no cover """ - Forwarding to plotTools.plotMesh + Forwarding to plotTools.plot_mesh """ scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']]) map_index = kwargs.pop('map_index', None if not scalars_demanded else 0) @@ -748,7 +748,7 @@ class Mesh3D(tfields.TensorMaps): if not dim_defined: kwargs['dim'] = 2 - return tfields.plotting.plotMesh(self, self.faces, **kwargs) + return tfields.plotting.plot_mesh(self, self.faces, **kwargs) if __name__ == '__main__': # pragma: no cover diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py index ac69c06446dbba6a3b74fdd985f755ad26f40538..b53fafdafcd2b8fe2602698be9f91f03112da48c 100644 --- a/tfields/plotting/__init__.py +++ b/tfields/plotting/__init__.py @@ -1,6 +1,9 @@ """ Core plotting tools for tfields library. Especially PlotOptions class is basis for many plotting expansions + +TODO: + * add other library backends. Do not restrict to mpl """ import warnings import matplotlib.pyplot as plt @@ -17,64 +20,6 @@ def setDefault(dictionary, attr, value): dictionary[attr] = value -def gca(dim=None, **kwargs): - """ - Forwarding to plt.gca but translating the dimension to projection - correct dimension - """ - if dim == 3: - axis = plt.gca(projection='3d', **kwargs) - else: - axis = plt.gca(**kwargs) - if dim != axisDim(axis): - if dim is not None: - warnings.warn("You have another dimension set as gca." - "I will force the new dimension to return.") - axis = plt.gcf().add_subplot(1, 1, 1, **kwargs) - return axis - - -def axisDim(axis): - """ - Returns int: axis dimension - """ - if hasattr(axis, 'get_zlim'): - return 3 - else: - return 2 - - -def setLabels(axis, *labels): - axis.set_xlabel(labels[0]) - axis.set_ylabel(labels[1]) - if axisDim(axis) == 3: - axis.set_zlabel(labels[2]) - - -def autoscale3D(axis, array=None, xLim=None, yLim=None, zLim=None): - if array is not None: - xMin, yMin, zMin = array.min(axis=0) - xMax, yMax, zMax = array.max(axis=0) - xLim = (xMin, xMax) - yLim = (yMin, yMax) - zLim = (zMin, zMax) - xLimAxis = axis.get_xlim() - yLimAxis = axis.get_ylim() - zLimAxis = axis.get_zlim() - - if not False: - # not empty axis - xMin = min(xLimAxis[0], xLim[0]) - yMin = min(yLimAxis[0], yLim[0]) - zMin = min(zLimAxis[0], zLim[0]) - xMax = max(xLimAxis[1], xLim[1]) - yMax = max(yLimAxis[1], yLim[1]) - zMax = max(zLimAxis[1], zLim[1]) - axis.set_xlim([xMin, xMax]) - axis.set_ylim([yMin, yMax]) - axis.set_zlim([zMin, zMax]) - - class PlotOptions(object): """ processing kwargs for plotting functions and providing easy @@ -115,9 +60,9 @@ class PlotOptions(object): if dim is None: if self._axis is None: dim = 2 - dim = axisDim(self._axis) + dim = axis_dim(self._axis) elif self._axis is not None: - if not dim == axisDim(self._axis): + if not dim == axis_dim(self._axis): raise ValueError("Axis and dim argument are in conflict.") if dim not in [2, 3]: raise NotImplementedError("Dimensions other than 2 or 3 are not supported.") @@ -198,7 +143,7 @@ class PlotOptions(object): cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', vminDefault=None, vmaxDefault=None) - colors = getColorsInverse(colors, cmap, vmin, vmax) + colors = to_scalars(colors, cmap, vmin, vmax) self.plotKwargs['vmin'] = vmin self.plotKwargs['vmax'] = vmax self.plotKwargs['cmap'] = cmap @@ -214,7 +159,7 @@ class PlotOptions(object): self.setVminVmaxAuto(vmin, vmax, colors) # update vmin and vmax cmap, vmin, vmax = self.getNormArgs() - colors = getColors(colors, + colors = to_colors(colors, vmin=vmin, vmax=vmax, cmap=cmap) diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py index a648a060933fbf49f1d8449fbde2933abf06e7a2..0469fb873c54c04696ad1e619654c332fc8cf2a3 100644 --- a/tfields/plotting/mpl.py +++ b/tfields/plotting/mpl.py @@ -1,15 +1,140 @@ +""" +Matplotlib specific plotting +""" import tfields + import numpy as np import warnings +import os import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib.patches import Circle import mpl_toolkits.mplot3d as plt3D +from mpl_toolkits.axes_grid1 import make_axes_locatable +import matplotlib.dates as dates +from itertools import cycle +import logging +def gca(dim=None, **kwargs): + """ + Forwarding to plt.gca but translating the dimension to projection + correct dimension + """ + if dim == 3: + axis = plt.gca(projection='3d', **kwargs) + else: + axis = plt.gca(**kwargs) + if dim != axis_dim(axis): + if dim is not None: + warnings.warn("You have another dimension set as gca." + "I will force the new dimension to return.") + axis = plt.gcf().add_subplot(1, 1, 1, **kwargs) + return axis + + +def upgrade_style(style, source, dest="~/.config/matplotlib/"): + """ + Copy a style file at <origionalFilePath> to the <dest> which is the foreseen + local matplotlib rc dir by default + The style will be name <style>.mplstyle + Args: + style (str): name of style + source (str): full path to mplstyle file to use + dest (str): local directory to copy the file to. Matpotlib has to + search this directory for mplstyle files! + """ + styleExtension = 'mplstyle' + path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension)) + source = tfields.lib.in_out.resolve(source) + tfields.lib.in_out.cp(source, path) -def plotArray(array, **kwargs): +def set_style(style='tfields', dest="~/.config/matplotlib/"): + """ + Set the matplotlib style of name + Important: + Either you + Args: + style (str) + dest (str): local directory to use file from. if None, use default maplotlib styles + """ + if dest is None: + path = style + else: + styleExtension = 'mplstyle' + path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension)) + try: + plt.style.use(path) + except IOError: + log = logging.getLogger() + if style == 'tfields': + log.warning("I will copy the default style to {dest}." + .format(**locals())) + source = os.path.join(os.path.dirname(__file__), + style + '.' + styleExtension) + upgrade_style(style, source, dest) + set_style(style) + else: + log.error("Could not set style {path}. Probably you would want to" + "call tfields.plotting.upgrade_style(<style>, " + "<path to mplstyle file that should be copied>)" + "once".format(**locals())) + + +def save(path, *fmts, **kwargs): + """ + Args: + path (str): path without extension to save to + *fmts (str): format of the figure to save. If multiple are given, create + that many files + **kwargs: + axis + fig + """ + log = logging.getLogger() + + # catch figure from axis or fig + axis = kwargs.get('axis', None) + if axis is None: + figDefault = plt.gcf() + axis = gca() + else: + figDefault = axis.figure + fig = kwargs.get('fig', figDefault) + + # set current figure + plt.figure(fig.number) + + # crop the plot down based on the extents of the artists in the plot + kwargs['bbox_inches'] = kwargs.pop('bbox_inches', 'tight') + if kwargs['bbox_inches'] == 'tight': + extraArtists = None + for ax in fig.get_axes(): + firstLabel = ax.get_legend_handles_labels()[0] or None + if firstLabel: + if not extraArtists: + extraArtists = [] + extraArtists.append(firstLabel) + kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists', extraArtists) + + if len(fmts) != 0: + for fmt in fmts: + if path.endswith('.'): + newFilePath = path + fmt + elif '{fmt}' in path: + newFilePath = path.format(**locals()) + else: + newFilePath = path + '.' + fmt + save(newFilePath, **kwargs) + else: + path = tfields.lib.in_out.resolve(path) + log.info("Saving figure as {0}".format(path)) + plt.savefig(path, + **kwargs) + + +def plot_array(array, **kwargs): """ Points3D plotting method. @@ -30,7 +155,7 @@ def plotArray(array, **kwargs): labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)']) xAxis, yAxis, zAxis = po.getXYZAxis() - tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) + tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList)) if zAxis is None: args = [array[:, xAxis], array[:, yAxis]] @@ -43,7 +168,7 @@ def plotArray(array, **kwargs): return artist -def plotMesh(vertices, faces, **kwargs): +def plot_mesh(vertices, faces, **kwargs): """ Args: axis (matplotlib axis) @@ -85,7 +210,7 @@ def plotMesh(vertices, faces, **kwargs): directionVector = np.array([1., 1., 1.]) directionVector[xAxis] = 0. directionVector[yAxis] = 0. - normVectors = mesh.triangles.norms() + normVectors = mesh.triangles().norms() dotProduct = np.dot(normVectors, directionVector) nFacesInitial = len(faces) faces = faces[dotProduct > 0] @@ -108,7 +233,7 @@ def plotMesh(vertices, faces, **kwargs): d = po.plotKwargs d['xAxis'] = xAxis d['yAxis'] = yAxis - artist = plotArray(vertices, **d) + artist = plot_array(vertices, **d) elif po.dim == 3: label = po.pop('label', None) color = po.retrieveChain('color', 'c', 'facecolors', @@ -141,7 +266,7 @@ def plotMesh(vertices, faces, **kwargs): artist.set_alpha(alpha) # for some reason auto-scale does not work - tfields.plotting.autoscale3D(po.axis, array=vertices) + tfields.plotting.autoscale_3d(po.axis, array=vertices) # legend lables do not work at all as an argument if label: @@ -152,12 +277,15 @@ def plotMesh(vertices, faces, **kwargs): artist._facecolors2d = None labelList = ['x (m)', 'y (m)', 'z (m)'] - tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) + tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList)) + + else: + raise NotImplementedError("Dimension != 2|3") return artist -def plotVectorField(points, vectors, **kwargs): +def plot_tensor_field(points, vectors, **kwargs): """ Args: points (array_like): base vectors @@ -173,14 +301,16 @@ def plotVectorField(points, vectors, **kwargs): artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis], vector[xAxis], vector[yAxis], vector[zAxis], **po.plotKwargs)) - else: + elif po.dim == 2: artists.append(po.axis.quiver(point[xAxis], point[yAxis], vector[xAxis], vector[yAxis], **po.plotKwargs)) + else: + raise NotImplementedError("Dimension != 2|3") return artists -def plotPlane(point, normal, **kwargs): +def plot_plane(point, normal, **kwargs): def plot_vector(fig, orig, v, color='blue'): axis = fig.gca(projection='3d') @@ -241,7 +371,7 @@ def plotPlane(point, normal, **kwargs): pathpatch_translate(patch, (point[0], point[1], point[2])) -def plotSphere(point, radius, **kwargs): +def plot_sphere(point, radius, **kwargs): po = tfields.plotting.PlotOptions(kwargs) # Make data u = np.linspace(0, 2 * np.pi, 100) @@ -254,10 +384,34 @@ def plotSphere(point, radius, **kwargs): return po.axis.plot_surface(x, y, z, **po.plotKwargs) +def plot_function(fun, **kwargs): + """ + Args: + axis (matplotlib.Axis) object + + Returns: + Artist or list of Artists (imitating the axis.scatter/plot behaviour). + Better Artist not list of Artists + """ + import numpy as np + labelList = ['x', 'f(x)'] + po = tfields.plotting.PlotOptions(kwargs) + tfields.plotting.set_labels(po.axis, *labelList) + xMin, xMax = po.pop('xMin', 0), po.pop('xMax', 1) + n = po.pop('n', 100) + vals = np.linspace(xMin, xMax, n) + args = (vals, map(fun, vals)) + artist = po.axis.plot(*args, + **po.plotKwargs) + return artist + + """ Color section """ -def getColors(scalars, cmap=None, vmin=None, vmax=None): + + +def to_colors(scalars, cmap=None, vmin=None, vmax=None): """ retrieve the colors for a list of scalars """ @@ -272,8 +426,9 @@ def getColors(scalars, cmap=None, vmin=None, vmax=None): return colorMap(map(norm, scalars)) -def getColorsInverse(colors, cmap, vmin, vmax): +def to_scalars(colors, cmap, vmin, vmax): """ + Inverse 'to_colors' Reconstruct the numeric values (0 - 1) of given Args: colors (list or rgba tuple) @@ -283,7 +438,7 @@ def getColorsInverse(colors, cmap, vmin, vmax): """ # colors = np.array(colors)/255. r = np.linspace(vmin, vmax, 256) - norm = matplotlib.colors.Normalize(vmin, vmax) + norm = mpl.colors.Normalize(vmin, vmax) mapvals = cmap(norm(r))[:, :4] # there are 4 channels: r,g,b,a scalars = [] for color in colors: @@ -292,8 +447,159 @@ def getColorsInverse(colors, cmap, vmin, vmax): return scalars +def colormap(seq): + """ + Args: + seq (iterable): a sequence of floats and RGB-tuples. The floats should be increasing + and in the interval (0,1). + Returns: + LinearSegmentedColormap + """ + seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] + cdict = {'red': [], 'green': [], 'blue': []} + for i, item in enumerate(seq): + if isinstance(item, float): + r1, g1, b1 = seq[i - 1] + r2, g2, b2 = seq[i + 1] + cdict['red'].append([item, r1, r2]) + cdict['green'].append([item, g1, g2]) + cdict['blue'].append([item, b1, b2]) + return mpl.colors.LinearSegmentedColormap('CustomMap', cdict) + + +def color_cycle(colormap=None, n=None): + """ + Args: + colormap (matplotlib colormap): e.g. plotTools.plt.cm.coolwarm + n (int): needed for colormap argument + """ + if colormap: + color_rgb = to_colors(np.linspace(0, 1, n), cmap=colormap, vmin=0, vmax=1) + colors = map(lambda rgb: '#%02x%02x%02x' % (rgb[0] * 255, + rgb[1] * 255, + rgb[2] * 255), + tuple(color_rgb[:, 0:-1])) + else: + colors = list([color['color'] for color in mpl.rcParams['axes.prop_cycle']]) + return cycle(colors) + + +""" +Display section +""" + + +def axis_dim(axis): + """ + Returns int: axis dimension + """ + if hasattr(axis, 'get_zlim'): + return 3 + else: + return 2 + + +def set_aspect_equal(axis): + """Fix equal aspect bug for 3D plots.""" + + if axis_dim(axis) == 2: + axis.set_aspect('equal') + return + + xlim = axis.get_xlim3d() + ylim = axis.get_ylim3d() + zlim = axis.get_zlim3d() + + from numpy import mean + xmean = mean(xlim) + ymean = mean(ylim) + zmean = mean(zlim) + + plot_radius = max([abs(lim - mean_) + for lims, mean_ in ((xlim, xmean), + (ylim, ymean), + (zlim, zmean)) + for lim in lims]) + + axis.set_xlim3d([xmean - plot_radius, xmean + plot_radius]) + axis.set_ylim3d([ymean - plot_radius, ymean + plot_radius]) + axis.set_zlim3d([zmean - plot_radius, zmean + plot_radius]) + + +def set_axis_off(axis): + if axis_dim(axis) == 2: + axis.set_axis_off() + else: + axis._axis3don = False + + +def autoscale_3d(axis, array=None, xLim=None, yLim=None, zLim=None): + if array is not None: + xMin, yMin, zMin = array.min(axis=0) + xMax, yMax, zMax = array.max(axis=0) + xLim = (xMin, xMax) + yLim = (yMin, yMax) + zLim = (zMin, zMax) + xLimAxis = axis.get_xlim() + yLimAxis = axis.get_ylim() + zLimAxis = axis.get_zlim() + + if not False: + # not empty axis + xMin = min(xLimAxis[0], xLim[0]) + yMin = min(yLimAxis[0], yLim[0]) + zMin = min(zLimAxis[0], zLim[0]) + xMax = max(xLimAxis[1], xLim[1]) + yMax = max(yLimAxis[1], yLim[1]) + zMax = max(zLimAxis[1], zLim[1]) + axis.set_xlim([xMin, xMax]) + axis.set_ylim([yMin, yMax]) + axis.set_zlim([zMin, zMax]) + + +def setLegend(axis, artists): + handles = [] + for artist in artists: + if isinstance(artist, list): + handles.append(artist[0]) + else: + handles.append(artist) + axis.legend(handles=handles) + + +def set_color_bar(axis, artist, label=None, divide=True, **kwargs): + # colorbar + if divide: + divider = make_axes_locatable(axis) + axis = divider.append_axes("right", size="2%", pad=0.05) + cbar = plt.colorbar(artist, cax=axis, **kwargs) + + # label + if label is None: + artLabel = artist.get_label() + if artLabel: + label = artLabel + if label is not None: + labelpad = 30 + cbar.set_label(label, rotation=270, labelpad=labelpad) + return cbar + + +def set_labels(axis, *labels): + axis.set_xlabel(labels[0]) + axis.set_ylabel(labels[1]) + if axis_dim(axis) == 3: + axis.set_zlabel(labels[2]) + + +def set_formatter(sub_axis=None, formatter=dates.DateFormatter('%d-%m-%y')): + if sub_axis is None: + axis = gca() + sub_axis = axis.xaxis + sub_axis.set_major_formatter(formatter) + + if __name__ == '__main__': - import tfields m = tfields.Mesh3D.grid((0, 2, 2), (0, 1, 3), (0, 0, 1)) m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0]))) art1 = m.plot(dim=3, map_index=0, label='twenty') @@ -302,10 +608,4 @@ if __name__ == '__main__': m.maps[0].fields.append(tfields.Tensors(np.arange(m.faces.shape[0]))) art = m.plot(dim=3, map_index=0, edgecolor='k', vmin=-1, vmax=1, label="something") - plotSphere([7, 0, 1], 3) - - # mpt.setLegend(mpt.gca(3), [art1, art]) - # mpt.setAspectEqual(mpt.gca()) - # mpt.setView(vector=[0, 0, 1]) - # mpt.save('/tmp/test', 'png') - # mpt.plt.show() + plot_sphere([7, 0, 1], 3)