Commit 5750f527 authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

plotting for Tensors

parent 6a7107c3
......@@ -488,7 +488,7 @@ class Tensors(AbstractNdarray):
''' transform all raw inputs to cls type with correct coordSys. Also
automatically make a copy of those instances that are of the correct
type already.'''
objects = [cls(t, **kwargs) for t in objects]
objects = [cls.__new__(cls, t, **kwargs) for t in objects]
''' check rank and dimension equality '''
if not len(set([t.rank for t in objects])) == 1:
......@@ -687,7 +687,8 @@ class Tensors(AbstractNdarray):
>>> p.mirror(1)
>>> assert p.equal([[1, -2, 3], [4, -5, 6], [1, -2, -6]])
multiple coordinates can be mirrored. Eg. a point mirrorion would be
multiple coordinates can be mirrored at the same time
i.e. a point mirrorion would be
>>> p = tfields.Tensors([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.mirror([0,2])
>>> assert p.equal([[-1, 2, -3], [-4, 5, -6], [-1, 2., 6.]])
......@@ -696,7 +697,7 @@ class Tensors(AbstractNdarray):
The mirroring will only be applied to the points meeting the condition.
>>> import sympy
>>> x, y, z = sympy.symbols('x y z')
>>> p.mirror([0,2], y > 3)
>>> p.mirror([0, 2], y > 3)
>>> p.equal([[-1, 2, -3], [4, 5, 6], [-1, 2, 6]])
True
......@@ -707,20 +708,28 @@ class Tensors(AbstractNdarray):
condition = self.evalf(condition)
if isinstance(coordinate, list) or isinstance(coordinate, tuple):
for c in coordinate:
self.mirror(c, condition)
self.mirror(c, condition=condition)
elif isinstance(coordinate, int):
self[:, coordinate][condition] *= -1
else:
raise TypeError()
def to_segment(self, segment, num_segments, coordinate,
periodicity=2 * np.pi, offset=0,
periodicity=2 * np.pi, offset=0.,
coordSys=None):
"""
For circular (close into themself after
<periodicity>) coordinates at index <coordinate> assume
<num_segments> segments and transform all values to
segment number <segment>
Args:
segment (int): segment index (starting at 0)
num_segments (int): number of segments
coordinate (int): coordinate index
periodicity (float): after what lenght, the coordiante repeats
offset (float): offset in the mapping
coordSys (str or sympy.CoordinateSystem): in which coord sys the
transformation should be done
Examples:
>>> import tfields
>>> import numpy as np
......@@ -1076,6 +1085,13 @@ class Tensors(AbstractNdarray):
evalfs, evecs = np.linalg.eigh(cov)
return (evecs * evalfs.T).T
def plot(self, **kwargs):
"""
Forwarding to tfields.lib.plotting.plotArray
"""
artist = tfields.plotting.plot_array(self, **kwargs)
return artist
class TensorFields(Tensors):
"""
......
......@@ -111,3 +111,5 @@ else:
from . import symbolics
from . import sets
from . import util
from . import in_out
from . import log
......@@ -725,7 +725,7 @@ class Mesh3D(tfields.TensorMaps):
def plot(self, **kwargs): # pragma: no cover
"""
Forwarding to plotTools.plotMesh
Forwarding to plotTools.plot_mesh
"""
scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']])
map_index = kwargs.pop('map_index', None if not scalars_demanded else 0)
......@@ -748,7 +748,7 @@ class Mesh3D(tfields.TensorMaps):
if not dim_defined:
kwargs['dim'] = 2
return tfields.plotting.plotMesh(self, self.faces, **kwargs)
return tfields.plotting.plot_mesh(self, self.faces, **kwargs)
if __name__ == '__main__': # pragma: no cover
......
"""
Core plotting tools for tfields library. Especially PlotOptions class
is basis for many plotting expansions
TODO:
* add other library backends. Do not restrict to mpl
"""
import warnings
import matplotlib.pyplot as plt
......@@ -17,64 +20,6 @@ def setDefault(dictionary, attr, value):
dictionary[attr] = value
def gca(dim=None, **kwargs):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if dim == 3:
axis = plt.gca(projection='3d', **kwargs)
else:
axis = plt.gca(**kwargs)
if dim != axisDim(axis):
if dim is not None:
warnings.warn("You have another dimension set as gca."
"I will force the new dimension to return.")
axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
return axis
def axisDim(axis):
"""
Returns int: axis dimension
"""
if hasattr(axis, 'get_zlim'):
return 3
else:
return 2
def setLabels(axis, *labels):
axis.set_xlabel(labels[0])
axis.set_ylabel(labels[1])
if axisDim(axis) == 3:
axis.set_zlabel(labels[2])
def autoscale3D(axis, array=None, xLim=None, yLim=None, zLim=None):
if array is not None:
xMin, yMin, zMin = array.min(axis=0)
xMax, yMax, zMax = array.max(axis=0)
xLim = (xMin, xMax)
yLim = (yMin, yMax)
zLim = (zMin, zMax)
xLimAxis = axis.get_xlim()
yLimAxis = axis.get_ylim()
zLimAxis = axis.get_zlim()
if not False:
# not empty axis
xMin = min(xLimAxis[0], xLim[0])
yMin = min(yLimAxis[0], yLim[0])
zMin = min(zLimAxis[0], zLim[0])
xMax = max(xLimAxis[1], xLim[1])
yMax = max(yLimAxis[1], yLim[1])
zMax = max(zLimAxis[1], zLim[1])
axis.set_xlim([xMin, xMax])
axis.set_ylim([yMin, yMax])
axis.set_zlim([zMin, zMax])
class PlotOptions(object):
"""
processing kwargs for plotting functions and providing easy
......@@ -115,9 +60,9 @@ class PlotOptions(object):
if dim is None:
if self._axis is None:
dim = 2
dim = axisDim(self._axis)
dim = axis_dim(self._axis)
elif self._axis is not None:
if not dim == axisDim(self._axis):
if not dim == axis_dim(self._axis):
raise ValueError("Axis and dim argument are in conflict.")
if dim not in [2, 3]:
raise NotImplementedError("Dimensions other than 2 or 3 are not supported.")
......@@ -198,7 +143,7 @@ class PlotOptions(object):
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
colors = getColorsInverse(colors, cmap, vmin, vmax)
colors = to_scalars(colors, cmap, vmin, vmax)
self.plotKwargs['vmin'] = vmin
self.plotKwargs['vmax'] = vmax
self.plotKwargs['cmap'] = cmap
......@@ -214,7 +159,7 @@ class PlotOptions(object):
self.setVminVmaxAuto(vmin, vmax, colors)
# update vmin and vmax
cmap, vmin, vmax = self.getNormArgs()
colors = getColors(colors,
colors = to_colors(colors,
vmin=vmin,
vmax=vmax,
cmap=cmap)
......
"""
Matplotlib specific plotting
"""
import tfields
import numpy as np
import warnings
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import mpl_toolkits.mplot3d as plt3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.dates as dates
from itertools import cycle
import logging
def gca(dim=None, **kwargs):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if dim == 3:
axis = plt.gca(projection='3d', **kwargs)
else:
axis = plt.gca(**kwargs)
if dim != axis_dim(axis):
if dim is not None:
warnings.warn("You have another dimension set as gca."
"I will force the new dimension to return.")
axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
return axis
def upgrade_style(style, source, dest="~/.config/matplotlib/"):
"""
Copy a style file at <origionalFilePath> to the <dest> which is the foreseen
local matplotlib rc dir by default
The style will be name <style>.mplstyle
Args:
style (str): name of style
source (str): full path to mplstyle file to use
dest (str): local directory to copy the file to. Matpotlib has to
search this directory for mplstyle files!
"""
styleExtension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension))
source = tfields.lib.in_out.resolve(source)
tfields.lib.in_out.cp(source, path)
def plotArray(array, **kwargs):
def set_style(style='tfields', dest="~/.config/matplotlib/"):
"""
Set the matplotlib style of name
Important:
Either you
Args:
style (str)
dest (str): local directory to use file from. if None, use default maplotlib styles
"""
if dest is None:
path = style
else:
styleExtension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension))
try:
plt.style.use(path)
except IOError:
log = logging.getLogger()
if style == 'tfields':
log.warning("I will copy the default style to {dest}."
.format(**locals()))
source = os.path.join(os.path.dirname(__file__),
style + '.' + styleExtension)
upgrade_style(style, source, dest)
set_style(style)
else:
log.error("Could not set style {path}. Probably you would want to"
"call tfields.plotting.upgrade_style(<style>, "
"<path to mplstyle file that should be copied>)"
"once".format(**locals()))
def save(path, *fmts, **kwargs):
"""
Args:
path (str): path without extension to save to
*fmts (str): format of the figure to save. If multiple are given, create
that many files
**kwargs:
axis
fig
"""
log = logging.getLogger()
# catch figure from axis or fig
axis = kwargs.get('axis', None)
if axis is None:
figDefault = plt.gcf()
axis = gca()
else:
figDefault = axis.figure
fig = kwargs.get('fig', figDefault)
# set current figure
plt.figure(fig.number)
# crop the plot down based on the extents of the artists in the plot
kwargs['bbox_inches'] = kwargs.pop('bbox_inches', 'tight')
if kwargs['bbox_inches'] == 'tight':
extraArtists = None
for ax in fig.get_axes():
firstLabel = ax.get_legend_handles_labels()[0] or None
if firstLabel:
if not extraArtists:
extraArtists = []
extraArtists.append(firstLabel)
kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists', extraArtists)
if len(fmts) != 0:
for fmt in fmts:
if path.endswith('.'):
newFilePath = path + fmt
elif '{fmt}' in path:
newFilePath = path.format(**locals())
else:
newFilePath = path + '.' + fmt
save(newFilePath, **kwargs)
else:
path = tfields.lib.in_out.resolve(path)
log.info("Saving figure as {0}".format(path))
plt.savefig(path,
**kwargs)
def plot_array(array, **kwargs):
"""
Points3D plotting method.
......@@ -30,7 +155,7 @@ def plotArray(array, **kwargs):
labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)'])
xAxis, yAxis, zAxis = po.getXYZAxis()
tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList))
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
if zAxis is None:
args = [array[:, xAxis],
array[:, yAxis]]
......@@ -43,7 +168,7 @@ def plotArray(array, **kwargs):
return artist
def plotMesh(vertices, faces, **kwargs):
def plot_mesh(vertices, faces, **kwargs):
"""
Args:
axis (matplotlib axis)
......@@ -85,7 +210,7 @@ def plotMesh(vertices, faces, **kwargs):
directionVector = np.array([1., 1., 1.])
directionVector[xAxis] = 0.
directionVector[yAxis] = 0.
normVectors = mesh.triangles.norms()
normVectors = mesh.triangles().norms()
dotProduct = np.dot(normVectors, directionVector)
nFacesInitial = len(faces)
faces = faces[dotProduct > 0]
......@@ -108,7 +233,7 @@ def plotMesh(vertices, faces, **kwargs):
d = po.plotKwargs
d['xAxis'] = xAxis
d['yAxis'] = yAxis
artist = plotArray(vertices, **d)
artist = plot_array(vertices, **d)
elif po.dim == 3:
label = po.pop('label', None)
color = po.retrieveChain('color', 'c', 'facecolors',
......@@ -141,7 +266,7 @@ def plotMesh(vertices, faces, **kwargs):
artist.set_alpha(alpha)
# for some reason auto-scale does not work
tfields.plotting.autoscale3D(po.axis, array=vertices)
tfields.plotting.autoscale_3d(po.axis, array=vertices)
# legend lables do not work at all as an argument
if label:
......@@ -152,12 +277,15 @@ def plotMesh(vertices, faces, **kwargs):
artist._facecolors2d = None
labelList = ['x (m)', 'y (m)', 'z (m)']
tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList))
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
else:
raise NotImplementedError("Dimension != 2|3")
return artist
def plotVectorField(points, vectors, **kwargs):
def plot_tensor_field(points, vectors, **kwargs):
"""
Args:
points (array_like): base vectors
......@@ -173,14 +301,16 @@ def plotVectorField(points, vectors, **kwargs):
artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis],
vector[xAxis], vector[yAxis], vector[zAxis],
**po.plotKwargs))
else:
elif po.dim == 2:
artists.append(po.axis.quiver(point[xAxis], point[yAxis],
vector[xAxis], vector[yAxis],
**po.plotKwargs))
else:
raise NotImplementedError("Dimension != 2|3")
return artists
def plotPlane(point, normal, **kwargs):
def plot_plane(point, normal, **kwargs):
def plot_vector(fig, orig, v, color='blue'):
axis = fig.gca(projection='3d')
......@@ -241,7 +371,7 @@ def plotPlane(point, normal, **kwargs):
pathpatch_translate(patch, (point[0], point[1], point[2]))
def plotSphere(point, radius, **kwargs):
def plot_sphere(point, radius, **kwargs):
po = tfields.plotting.PlotOptions(kwargs)
# Make data
u = np.linspace(0, 2 * np.pi, 100)
......@@ -254,10 +384,34 @@ def plotSphere(point, radius, **kwargs):
return po.axis.plot_surface(x, y, z, **po.plotKwargs)
def plot_function(fun, **kwargs):
"""
Args:
axis (matplotlib.Axis) object
Returns:
Artist or list of Artists (imitating the axis.scatter/plot behaviour).
Better Artist not list of Artists
"""
import numpy as np
labelList = ['x', 'f(x)']
po = tfields.plotting.PlotOptions(kwargs)
tfields.plotting.set_labels(po.axis, *labelList)
xMin, xMax = po.pop('xMin', 0), po.pop('xMax', 1)
n = po.pop('n', 100)
vals = np.linspace(xMin, xMax, n)
args = (vals, map(fun, vals))
artist = po.axis.plot(*args,
**po.plotKwargs)
return artist
"""
Color section
"""
def getColors(scalars, cmap=None, vmin=None, vmax=None):
def to_colors(scalars, cmap=None, vmin=None, vmax=None):
"""
retrieve the colors for a list of scalars
"""
......@@ -272,8 +426,9 @@ def getColors(scalars, cmap=None, vmin=None, vmax=None):
return colorMap(map(norm, scalars))
def getColorsInverse(colors, cmap, vmin, vmax):
def to_scalars(colors, cmap, vmin, vmax):
"""
Inverse 'to_colors'
Reconstruct the numeric values (0 - 1) of given
Args:
colors (list or rgba tuple)
......@@ -283,7 +438,7 @@ def getColorsInverse(colors, cmap, vmin, vmax):
"""
# colors = np.array(colors)/255.
r = np.linspace(vmin, vmax, 256)
norm = matplotlib.colors.Normalize(vmin, vmax)
norm = mpl.colors.Normalize(vmin, vmax)
mapvals = cmap(norm(r))[:, :4] # there are 4 channels: r,g,b,a
scalars = []
for color in colors:
......@@ -292,8 +447,159 @@ def getColorsInverse(colors, cmap, vmin, vmax):
return scalars
def colormap(seq):
"""
Args:
seq (iterable): a sequence of floats and RGB-tuples. The floats should be increasing
and in the interval (0,1).
Returns:
LinearSegmentedColormap
"""
seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3]
cdict = {'red': [], 'green': [], 'blue': []}
for i, item in enumerate(seq):
if isinstance(item, float):
r1, g1, b1 = seq[i - 1]
r2, g2, b2 = seq[i + 1]
cdict['red'].append([item, r1, r2])
cdict['green'].append([item, g1, g2])
cdict['blue'].append([item, b1, b2])
return mpl.colors.LinearSegmentedColormap('CustomMap', cdict)
def color_cycle(colormap=None, n=None):
"""
Args:
colormap (matplotlib colormap): e.g. plotTools.plt.cm.coolwarm
n (int): needed for colormap argument
"""
if colormap:
color_rgb = to_colors(np.linspace(0, 1, n), cmap=colormap, vmin=0, vmax=1)
colors = map(lambda rgb: '#%02x%02x%02x' % (rgb[0] * 255,
rgb[1] * 255,
rgb[2] * 255),
tuple(color_rgb[:, 0:-1]))
else:
colors = list([color['color'] for color in mpl.rcParams['axes.prop_cycle']])
return cycle(colors)
"""
Display section
"""
def axis_dim(axis):
"""
Returns int: axis dimension
"""
if hasattr(axis, 'get_zlim'):
return 3
else:
return 2
def set_aspect_equal(axis):
"""Fix equal aspect bug for 3D plots."""
if axis_dim(axis) == 2:
axis.set_aspect('equal')
return
xlim = axis.get_xlim3d()
ylim = axis.get_ylim3d()
zlim = axis.get_zlim3d()
from numpy import mean
xmean = mean(xlim)
ymean = mean(ylim)
zmean = mean(zlim)
plot_radius = max([abs(lim - mean_)
for lims, mean_ in ((xlim, xmean),
(ylim, ymean),
(zlim, zmean))
for lim in lims])
axis.set_xlim3d([xmean - plot_radius, xmean + plot_radius])
axis.set_ylim3d([ymean - plot_radius, ymean + plot_radius])
axis.set_zlim3d([zmean - plot_radius, zmean + plot_radius])
def set_axis_off(axis):
if axis_dim(axis) == 2:
axis.set_axis_off()
else:
axis._axis3don = False
def autoscale_3d(axis, array=None, xLim=None, yLim=None, zLim=None):
if array is not None:
xMin, yMin, zMin = array.min(axis=0)
xMax, yMax, zMax = array.max(axis=0)
xLim = (xMin, xMax)
yLim = (yMin, yMax)
zLim = (zMin, zMax)
xLimAxis = axis.get_xlim()
yLimAxis = axis.get_ylim()
zLimAxis = axis.get_zlim()
if not False:
# not empty axis
xMin = min(xLimAxis[0], xLim[0])
yMin = min(yLimAxis[0], yLim[0])
zMin = min(zLimAxis[0], zLim[0])
xMax = max(xLimAxis[1], xLim[1])
yMax = max(yLimAxis[1], yLim[1])
zMax = max(zLimAxis[1], zLim[1])
axis.set_xlim([xMin, xMax])
axis.set_ylim([yMin, yMax])
axis.set_zlim([zMin, zMax])
def setLegend(axis, artists):
handles = []
for artist in artists:
if isinstance(artist, list):
handles.append(artist[0])
else:
handles.append(artist)
axis.legend(handles=handles)
def set_color_bar(axis, artist, label=None, divide=True, **kwargs):
# colorbar
if divide:
divider = make_axes_locatable(axis)
axis = divider.append_axes("right", size="2%", pad=0.05)
cbar = plt.colorbar(artist, cax=axis, **kwargs)
# label
if label is None:
artLabel = artist.get_label()