Commit 5750f527 authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

plotting for Tensors

parent 6a7107c3
...@@ -488,7 +488,7 @@ class Tensors(AbstractNdarray): ...@@ -488,7 +488,7 @@ class Tensors(AbstractNdarray):
''' transform all raw inputs to cls type with correct coordSys. Also ''' transform all raw inputs to cls type with correct coordSys. Also
automatically make a copy of those instances that are of the correct automatically make a copy of those instances that are of the correct
type already.''' type already.'''
objects = [cls(t, **kwargs) for t in objects] objects = [cls.__new__(cls, t, **kwargs) for t in objects]
''' check rank and dimension equality ''' ''' check rank and dimension equality '''
if not len(set([t.rank for t in objects])) == 1: if not len(set([t.rank for t in objects])) == 1:
...@@ -687,7 +687,8 @@ class Tensors(AbstractNdarray): ...@@ -687,7 +687,8 @@ class Tensors(AbstractNdarray):
>>> p.mirror(1) >>> p.mirror(1)
>>> assert p.equal([[1, -2, 3], [4, -5, 6], [1, -2, -6]]) >>> assert p.equal([[1, -2, 3], [4, -5, 6], [1, -2, -6]])
multiple coordinates can be mirrored. Eg. a point mirrorion would be multiple coordinates can be mirrored at the same time
i.e. a point mirrorion would be
>>> p = tfields.Tensors([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]]) >>> p = tfields.Tensors([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.mirror([0,2]) >>> p.mirror([0,2])
>>> assert p.equal([[-1, 2, -3], [-4, 5, -6], [-1, 2., 6.]]) >>> assert p.equal([[-1, 2, -3], [-4, 5, -6], [-1, 2., 6.]])
...@@ -696,7 +697,7 @@ class Tensors(AbstractNdarray): ...@@ -696,7 +697,7 @@ class Tensors(AbstractNdarray):
The mirroring will only be applied to the points meeting the condition. The mirroring will only be applied to the points meeting the condition.
>>> import sympy >>> import sympy
>>> x, y, z = sympy.symbols('x y z') >>> x, y, z = sympy.symbols('x y z')
>>> p.mirror([0,2], y > 3) >>> p.mirror([0, 2], y > 3)
>>> p.equal([[-1, 2, -3], [4, 5, 6], [-1, 2, 6]]) >>> p.equal([[-1, 2, -3], [4, 5, 6], [-1, 2, 6]])
True True
...@@ -707,20 +708,28 @@ class Tensors(AbstractNdarray): ...@@ -707,20 +708,28 @@ class Tensors(AbstractNdarray):
condition = self.evalf(condition) condition = self.evalf(condition)
if isinstance(coordinate, list) or isinstance(coordinate, tuple): if isinstance(coordinate, list) or isinstance(coordinate, tuple):
for c in coordinate: for c in coordinate:
self.mirror(c, condition) self.mirror(c, condition=condition)
elif isinstance(coordinate, int): elif isinstance(coordinate, int):
self[:, coordinate][condition] *= -1 self[:, coordinate][condition] *= -1
else: else:
raise TypeError() raise TypeError()
def to_segment(self, segment, num_segments, coordinate, def to_segment(self, segment, num_segments, coordinate,
periodicity=2 * np.pi, offset=0, periodicity=2 * np.pi, offset=0.,
coordSys=None): coordSys=None):
""" """
For circular (close into themself after For circular (close into themself after
<periodicity>) coordinates at index <coordinate> assume <periodicity>) coordinates at index <coordinate> assume
<num_segments> segments and transform all values to <num_segments> segments and transform all values to
segment number <segment> segment number <segment>
Args:
segment (int): segment index (starting at 0)
num_segments (int): number of segments
coordinate (int): coordinate index
periodicity (float): after what lenght, the coordiante repeats
offset (float): offset in the mapping
coordSys (str or sympy.CoordinateSystem): in which coord sys the
transformation should be done
Examples: Examples:
>>> import tfields >>> import tfields
>>> import numpy as np >>> import numpy as np
...@@ -1076,6 +1085,13 @@ class Tensors(AbstractNdarray): ...@@ -1076,6 +1085,13 @@ class Tensors(AbstractNdarray):
evalfs, evecs = np.linalg.eigh(cov) evalfs, evecs = np.linalg.eigh(cov)
return (evecs * evalfs.T).T return (evecs * evalfs.T).T
def plot(self, **kwargs):
"""
Forwarding to tfields.lib.plotting.plotArray
"""
artist = tfields.plotting.plot_array(self, **kwargs)
return artist
class TensorFields(Tensors): class TensorFields(Tensors):
""" """
......
...@@ -111,3 +111,5 @@ else: ...@@ -111,3 +111,5 @@ else:
from . import symbolics from . import symbolics
from . import sets from . import sets
from . import util from . import util
from . import in_out
from . import log
...@@ -725,7 +725,7 @@ class Mesh3D(tfields.TensorMaps): ...@@ -725,7 +725,7 @@ class Mesh3D(tfields.TensorMaps):
def plot(self, **kwargs): # pragma: no cover def plot(self, **kwargs): # pragma: no cover
""" """
Forwarding to plotTools.plotMesh Forwarding to plotTools.plot_mesh
""" """
scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']]) scalars_demanded = any([v in kwargs for v in ['vmin', 'vmax', 'cmap']])
map_index = kwargs.pop('map_index', None if not scalars_demanded else 0) map_index = kwargs.pop('map_index', None if not scalars_demanded else 0)
...@@ -748,7 +748,7 @@ class Mesh3D(tfields.TensorMaps): ...@@ -748,7 +748,7 @@ class Mesh3D(tfields.TensorMaps):
if not dim_defined: if not dim_defined:
kwargs['dim'] = 2 kwargs['dim'] = 2
return tfields.plotting.plotMesh(self, self.faces, **kwargs) return tfields.plotting.plot_mesh(self, self.faces, **kwargs)
if __name__ == '__main__': # pragma: no cover if __name__ == '__main__': # pragma: no cover
......
""" """
Core plotting tools for tfields library. Especially PlotOptions class Core plotting tools for tfields library. Especially PlotOptions class
is basis for many plotting expansions is basis for many plotting expansions
TODO:
* add other library backends. Do not restrict to mpl
""" """
import warnings import warnings
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -17,64 +20,6 @@ def setDefault(dictionary, attr, value): ...@@ -17,64 +20,6 @@ def setDefault(dictionary, attr, value):
dictionary[attr] = value dictionary[attr] = value
def gca(dim=None, **kwargs):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if dim == 3:
axis = plt.gca(projection='3d', **kwargs)
else:
axis = plt.gca(**kwargs)
if dim != axisDim(axis):
if dim is not None:
warnings.warn("You have another dimension set as gca."
"I will force the new dimension to return.")
axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
return axis
def axisDim(axis):
"""
Returns int: axis dimension
"""
if hasattr(axis, 'get_zlim'):
return 3
else:
return 2
def setLabels(axis, *labels):
axis.set_xlabel(labels[0])
axis.set_ylabel(labels[1])
if axisDim(axis) == 3:
axis.set_zlabel(labels[2])
def autoscale3D(axis, array=None, xLim=None, yLim=None, zLim=None):
if array is not None:
xMin, yMin, zMin = array.min(axis=0)
xMax, yMax, zMax = array.max(axis=0)
xLim = (xMin, xMax)
yLim = (yMin, yMax)
zLim = (zMin, zMax)
xLimAxis = axis.get_xlim()
yLimAxis = axis.get_ylim()
zLimAxis = axis.get_zlim()
if not False:
# not empty axis
xMin = min(xLimAxis[0], xLim[0])
yMin = min(yLimAxis[0], yLim[0])
zMin = min(zLimAxis[0], zLim[0])
xMax = max(xLimAxis[1], xLim[1])
yMax = max(yLimAxis[1], yLim[1])
zMax = max(zLimAxis[1], zLim[1])
axis.set_xlim([xMin, xMax])
axis.set_ylim([yMin, yMax])
axis.set_zlim([zMin, zMax])
class PlotOptions(object): class PlotOptions(object):
""" """
processing kwargs for plotting functions and providing easy processing kwargs for plotting functions and providing easy
...@@ -115,9 +60,9 @@ class PlotOptions(object): ...@@ -115,9 +60,9 @@ class PlotOptions(object):
if dim is None: if dim is None:
if self._axis is None: if self._axis is None:
dim = 2 dim = 2
dim = axisDim(self._axis) dim = axis_dim(self._axis)
elif self._axis is not None: elif self._axis is not None:
if not dim == axisDim(self._axis): if not dim == axis_dim(self._axis):
raise ValueError("Axis and dim argument are in conflict.") raise ValueError("Axis and dim argument are in conflict.")
if dim not in [2, 3]: if dim not in [2, 3]:
raise NotImplementedError("Dimensions other than 2 or 3 are not supported.") raise NotImplementedError("Dimensions other than 2 or 3 are not supported.")
...@@ -198,7 +143,7 @@ class PlotOptions(object): ...@@ -198,7 +143,7 @@ class PlotOptions(object):
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None, vminDefault=None,
vmaxDefault=None) vmaxDefault=None)
colors = getColorsInverse(colors, cmap, vmin, vmax) colors = to_scalars(colors, cmap, vmin, vmax)
self.plotKwargs['vmin'] = vmin self.plotKwargs['vmin'] = vmin
self.plotKwargs['vmax'] = vmax self.plotKwargs['vmax'] = vmax
self.plotKwargs['cmap'] = cmap self.plotKwargs['cmap'] = cmap
...@@ -214,7 +159,7 @@ class PlotOptions(object): ...@@ -214,7 +159,7 @@ class PlotOptions(object):
self.setVminVmaxAuto(vmin, vmax, colors) self.setVminVmaxAuto(vmin, vmax, colors)
# update vmin and vmax # update vmin and vmax
cmap, vmin, vmax = self.getNormArgs() cmap, vmin, vmax = self.getNormArgs()
colors = getColors(colors, colors = to_colors(colors,
vmin=vmin, vmin=vmin,
vmax=vmax, vmax=vmax,
cmap=cmap) cmap=cmap)
......
"""
Matplotlib specific plotting
"""
import tfields import tfields
import numpy as np import numpy as np
import warnings import warnings
import os
import matplotlib as mpl import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Circle from matplotlib.patches import Circle
import mpl_toolkits.mplot3d as plt3D import mpl_toolkits.mplot3d as plt3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.dates as dates
from itertools import cycle
import logging
def gca(dim=None, **kwargs):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if dim == 3:
axis = plt.gca(projection='3d', **kwargs)
else:
axis = plt.gca(**kwargs)
if dim != axis_dim(axis):
if dim is not None:
warnings.warn("You have another dimension set as gca."
"I will force the new dimension to return.")
axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
return axis
def upgrade_style(style, source, dest="~/.config/matplotlib/"):
"""
Copy a style file at <origionalFilePath> to the <dest> which is the foreseen
local matplotlib rc dir by default
The style will be name <style>.mplstyle
Args:
style (str): name of style
source (str): full path to mplstyle file to use
dest (str): local directory to copy the file to. Matpotlib has to
search this directory for mplstyle files!
"""
styleExtension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension))
source = tfields.lib.in_out.resolve(source)
tfields.lib.in_out.cp(source, path)
def plotArray(array, **kwargs): def set_style(style='tfields', dest="~/.config/matplotlib/"):
"""
Set the matplotlib style of name
Important:
Either you
Args:
style (str)
dest (str): local directory to use file from. if None, use default maplotlib styles
"""
if dest is None:
path = style
else:
styleExtension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' + styleExtension))
try:
plt.style.use(path)
except IOError:
log = logging.getLogger()
if style == 'tfields':
log.warning("I will copy the default style to {dest}."
.format(**locals()))
source = os.path.join(os.path.dirname(__file__),
style + '.' + styleExtension)
upgrade_style(style, source, dest)
set_style(style)
else:
log.error("Could not set style {path}. Probably you would want to"
"call tfields.plotting.upgrade_style(<style>, "
"<path to mplstyle file that should be copied>)"
"once".format(**locals()))
def save(path, *fmts, **kwargs):
"""
Args:
path (str): path without extension to save to
*fmts (str): format of the figure to save. If multiple are given, create
that many files
**kwargs:
axis
fig
"""
log = logging.getLogger()
# catch figure from axis or fig
axis = kwargs.get('axis', None)
if axis is None:
figDefault = plt.gcf()
axis = gca()
else:
figDefault = axis.figure
fig = kwargs.get('fig', figDefault)
# set current figure
plt.figure(fig.number)
# crop the plot down based on the extents of the artists in the plot
kwargs['bbox_inches'] = kwargs.pop('bbox_inches', 'tight')
if kwargs['bbox_inches'] == 'tight':
extraArtists = None
for ax in fig.get_axes():
firstLabel = ax.get_legend_handles_labels()[0] or None
if firstLabel:
if not extraArtists:
extraArtists = []
extraArtists.append(firstLabel)
kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists', extraArtists)
if len(fmts) != 0:
for fmt in fmts:
if path.endswith('.'):
newFilePath = path + fmt
elif '{fmt}' in path:
newFilePath = path.format(**locals())
else:
newFilePath = path + '.' + fmt
save(newFilePath, **kwargs)
else:
path = tfields.lib.in_out.resolve(path)
log.info("Saving figure as {0}".format(path))
plt.savefig(path,
**kwargs)
def plot_array(array, **kwargs):
""" """
Points3D plotting method. Points3D plotting method.
...@@ -30,7 +155,7 @@ def plotArray(array, **kwargs): ...@@ -30,7 +155,7 @@ def plotArray(array, **kwargs):
labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)']) labelList = po.pop('labelList', ['x (m)', 'y (m)', 'z (m)'])
xAxis, yAxis, zAxis = po.getXYZAxis() xAxis, yAxis, zAxis = po.getXYZAxis()
tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
if zAxis is None: if zAxis is None:
args = [array[:, xAxis], args = [array[:, xAxis],
array[:, yAxis]] array[:, yAxis]]
...@@ -43,7 +168,7 @@ def plotArray(array, **kwargs): ...@@ -43,7 +168,7 @@ def plotArray(array, **kwargs):
return artist return artist
def plotMesh(vertices, faces, **kwargs): def plot_mesh(vertices, faces, **kwargs):
""" """
Args: Args:
axis (matplotlib axis) axis (matplotlib axis)
...@@ -85,7 +210,7 @@ def plotMesh(vertices, faces, **kwargs): ...@@ -85,7 +210,7 @@ def plotMesh(vertices, faces, **kwargs):
directionVector = np.array([1., 1., 1.]) directionVector = np.array([1., 1., 1.])
directionVector[xAxis] = 0. directionVector[xAxis] = 0.
directionVector[yAxis] = 0. directionVector[yAxis] = 0.
normVectors = mesh.triangles.norms() normVectors = mesh.triangles().norms()
dotProduct = np.dot(normVectors, directionVector) dotProduct = np.dot(normVectors, directionVector)
nFacesInitial = len(faces) nFacesInitial = len(faces)
faces = faces[dotProduct > 0] faces = faces[dotProduct > 0]
...@@ -108,7 +233,7 @@ def plotMesh(vertices, faces, **kwargs): ...@@ -108,7 +233,7 @@ def plotMesh(vertices, faces, **kwargs):
d = po.plotKwargs d = po.plotKwargs
d['xAxis'] = xAxis d['xAxis'] = xAxis
d['yAxis'] = yAxis d['yAxis'] = yAxis
artist = plotArray(vertices, **d) artist = plot_array(vertices, **d)
elif po.dim == 3: elif po.dim == 3:
label = po.pop('label', None) label = po.pop('label', None)
color = po.retrieveChain('color', 'c', 'facecolors', color = po.retrieveChain('color', 'c', 'facecolors',
...@@ -141,7 +266,7 @@ def plotMesh(vertices, faces, **kwargs): ...@@ -141,7 +266,7 @@ def plotMesh(vertices, faces, **kwargs):
artist.set_alpha(alpha) artist.set_alpha(alpha)
# for some reason auto-scale does not work # for some reason auto-scale does not work
tfields.plotting.autoscale3D(po.axis, array=vertices) tfields.plotting.autoscale_3d(po.axis, array=vertices)
# legend lables do not work at all as an argument # legend lables do not work at all as an argument
if label: if label:
...@@ -152,12 +277,15 @@ def plotMesh(vertices, faces, **kwargs): ...@@ -152,12 +277,15 @@ def plotMesh(vertices, faces, **kwargs):
artist._facecolors2d = None artist._facecolors2d = None
labelList = ['x (m)', 'y (m)', 'z (m)'] labelList = ['x (m)', 'y (m)', 'z (m)']
tfields.plotting.setLabels(po.axis, *po.getSortedLabels(labelList)) tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labelList))
else:
raise NotImplementedError("Dimension != 2|3")
return artist return artist
def plotVectorField(points, vectors, **kwargs): def plot_tensor_field(points, vectors, **kwargs):
""" """
Args: Args:
points (array_like): base vectors points (array_like): base vectors
...@@ -173,14 +301,16 @@ def plotVectorField(points, vectors, **kwargs): ...@@ -173,14 +301,16 @@ def plotVectorField(points, vectors, **kwargs):
artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis], artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis],
vector[xAxis], vector[yAxis], vector[zAxis], vector[xAxis], vector[yAxis], vector[zAxis],
**po.plotKwargs)) **po.plotKwargs))
else: elif po.dim == 2:
artists.append(po.axis.quiver(point[xAxis], point[yAxis], artists.append(po.axis.quiver(point[xAxis], point[yAxis],
vector[xAxis], vector[yAxis], vector[xAxis], vector[yAxis],
**po.plotKwargs)) **po.plotKwargs))
else:
raise NotImplementedError("Dimension != 2|3")
return artists return artists
def plotPlane(point, normal, **kwargs): def plot_plane(point, normal, **kwargs):
def plot_vector(fig, orig, v, color='blue'): def plot_vector(fig, orig, v, color='blue'):
axis = fig.gca(projection='3d') axis = fig.gca(projection='3d')
...@@ -241,7 +371,7 @@ def plotPlane(point, normal, **kwargs): ...@@ -241,7 +371,7 @@ def plotPlane(point, normal, **kwargs):
pathpatch_translate(patch, (point[0], point[1], point[2])) pathpatch_translate(patch, (point[0], point[1], point[2]))
def plotSphere(point, radius, **kwargs): def plot_sphere(point, radius, **kwargs):
po = tfields.plotting.PlotOptions(kwargs) po = tfields.plotting.PlotOptions(kwargs)
# Make data # Make data
u = np.linspace(0, 2 * np.pi, 100) u = np.linspace(0, 2 * np.pi, 100)
...@@ -254,10 +384,34 @@ def plotSphere(point, radius, **kwargs): ...@@ -254,10 +384,34 @@ def plotSphere(point, radius, **kwargs):
return po.axis.plot_surface(x, y, z, **po.plotKwargs) return po.axis.plot_surface(x, y, z, **po.plotKwargs)
def plot_function(fun, **kwargs):
"""
Args:
axis (matplotlib.Axis) object
Returns:
Artist or list of Artists (imitating the axis.scatter/plot behaviour).
Better Artist not list of Artists
"""
import numpy as np
labelList = ['x', 'f(x)']
po = tfields.plotting.PlotOptions(kwargs)
tfields.plotting.set_labels(po.axis, *labelList)
xMin, xMax = po.pop('xMin', 0), po.pop('xMax', 1)
n = po.pop('n', 100)
vals = np.linspace(xMin, xMax, n)
args = (vals, map(fun, vals))
artist = po.axis.plot(*args,
**po.plotKwargs)
return artist
""" """
Color section Color section
""" """
def getColors(scalars, cmap=None, vmin=None, vmax=None):
def to_colors(scalars, cmap=None, vmin=None, vmax=None):
""" """
retrieve the colors for a list of scalars retrieve the colors for a list of scalars
""" """
...@@ -272,8 +426,9 @@ def getColors(scalars, cmap=None, vmin=None, vmax=None): ...@@ -272,8 +426,9 @@ def getColors(scalars, cmap=None, vmin=None, vmax=None):
return colorMap(map(norm, scalars)) return colorMap(map(norm, scalars))
def getColorsInverse(colors, cmap, vmin, vmax): def to_scalars(colors, cmap, vmin, vmax):
""" """
Inverse 'to_colors'
Reconstruct the numeric values (0 - 1) of given Reconstruct the numeric values (0 - 1) of given
Args: Args:
colors (list or rgba tuple) colors (list or rgba tuple)
...@@ -283,7 +438,7 @@ def getColorsInverse(colors, cmap, vmin, vmax): ...@@ -283,7 +438,7 @@ def getColorsInverse(colors, cmap, vmin, vmax):
""" """
# colors = np.array(colors)/255. # colors = np.array(colors)/255.
r = np.linspace(vmin, vmax, 256) r = np.linspace(vmin, vmax, 256)