Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Daniel Boeckenhoff
tfields
Commits
5750f527
Commit
5750f527
authored
Jul 13, 2018
by
Daniel Boeckenhoff
Browse files
plotting for Tensors
parent
6a7107c3
Changes
5
Show whitespace changes
Inline
Side-by-side
tfields/core.py
View file @
5750f527
...
...
@@ -488,7 +488,7 @@ class Tensors(AbstractNdarray):
''' transform all raw inputs to cls type with correct coordSys. Also
automatically make a copy of those instances that are of the correct
type already.'''
objects
=
[
cls
(
t
,
**
kwargs
)
for
t
in
objects
]
objects
=
[
cls
.
__new__
(
cls
,
t
,
**
kwargs
)
for
t
in
objects
]
''' check rank and dimension equality '''
if
not
len
(
set
([
t
.
rank
for
t
in
objects
]))
==
1
:
...
...
@@ -687,7 +687,8 @@ class Tensors(AbstractNdarray):
>>> p.mirror(1)
>>> assert p.equal([[1, -2, 3], [4, -5, 6], [1, -2, -6]])
multiple coordinates can be mirrored. Eg. a point mirrorion would be
multiple coordinates can be mirrored at the same time
i.e. a point mirrorion would be
>>> p = tfields.Tensors([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.mirror([0,2])
>>> assert p.equal([[-1, 2, -3], [-4, 5, -6], [-1, 2., 6.]])
...
...
@@ -696,7 +697,7 @@ class Tensors(AbstractNdarray):
The mirroring will only be applied to the points meeting the condition.
>>> import sympy
>>> x, y, z = sympy.symbols('x y z')
>>> p.mirror([0,2], y > 3)
>>> p.mirror([0,
2], y > 3)
>>> p.equal([[-1, 2, -3], [4, 5, 6], [-1, 2, 6]])
True
...
...
@@ -707,20 +708,28 @@ class Tensors(AbstractNdarray):
condition
=
self
.
evalf
(
condition
)
if
isinstance
(
coordinate
,
list
)
or
isinstance
(
coordinate
,
tuple
):
for
c
in
coordinate
:
self
.
mirror
(
c
,
condition
)
self
.
mirror
(
c
,
condition
=
condition
)
elif
isinstance
(
coordinate
,
int
):
self
[:,
coordinate
][
condition
]
*=
-
1
else
:
raise
TypeError
()
def
to_segment
(
self
,
segment
,
num_segments
,
coordinate
,
periodicity
=
2
*
np
.
pi
,
offset
=
0
,
periodicity
=
2
*
np
.
pi
,
offset
=
0
.
,
coordSys
=
None
):
"""
For circular (close into themself after
<periodicity>) coordinates at index <coordinate> assume
<num_segments> segments and transform all values to
segment number <segment>
Args:
segment (int): segment index (starting at 0)
num_segments (int): number of segments
coordinate (int): coordinate index
periodicity (float): after what lenght, the coordiante repeats
offset (float): offset in the mapping
coordSys (str or sympy.CoordinateSystem): in which coord sys the
transformation should be done
Examples:
>>> import tfields
>>> import numpy as np
...
...
@@ -1076,6 +1085,13 @@ class Tensors(AbstractNdarray):
evalfs
,
evecs
=
np
.
linalg
.
eigh
(
cov
)
return
(
evecs
*
evalfs
.
T
).
T
def
plot
(
self
,
**
kwargs
):
"""
Forwarding to tfields.lib.plotting.plotArray
"""
artist
=
tfields
.
plotting
.
plot_array
(
self
,
**
kwargs
)
return
artist
class
TensorFields
(
Tensors
):
"""
...
...
tfields/lib/__init__.py
View file @
5750f527
...
...
@@ -111,3 +111,5 @@ else:
from
.
import
symbolics
from
.
import
sets
from
.
import
util
from
.
import
in_out
from
.
import
log
tfields/mesh3D.py
View file @
5750f527
...
...
@@ -725,7 +725,7 @@ class Mesh3D(tfields.TensorMaps):
def
plot
(
self
,
**
kwargs
):
# pragma: no cover
"""
Forwarding to plotTools.plot
M
esh
Forwarding to plotTools.plot
_m
esh
"""
scalars_demanded
=
any
([
v
in
kwargs
for
v
in
[
'vmin'
,
'vmax'
,
'cmap'
]])
map_index
=
kwargs
.
pop
(
'map_index'
,
None
if
not
scalars_demanded
else
0
)
...
...
@@ -748,7 +748,7 @@ class Mesh3D(tfields.TensorMaps):
if
not
dim_defined
:
kwargs
[
'dim'
]
=
2
return
tfields
.
plotting
.
plot
M
esh
(
self
,
self
.
faces
,
**
kwargs
)
return
tfields
.
plotting
.
plot
_m
esh
(
self
,
self
.
faces
,
**
kwargs
)
if
__name__
==
'__main__'
:
# pragma: no cover
...
...
tfields/plotting/__init__.py
View file @
5750f527
"""
Core plotting tools for tfields library. Especially PlotOptions class
is basis for many plotting expansions
TODO:
* add other library backends. Do not restrict to mpl
"""
import
warnings
import
matplotlib.pyplot
as
plt
...
...
@@ -17,64 +20,6 @@ def setDefault(dictionary, attr, value):
dictionary
[
attr
]
=
value
def
gca
(
dim
=
None
,
**
kwargs
):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if
dim
==
3
:
axis
=
plt
.
gca
(
projection
=
'3d'
,
**
kwargs
)
else
:
axis
=
plt
.
gca
(
**
kwargs
)
if
dim
!=
axisDim
(
axis
):
if
dim
is
not
None
:
warnings
.
warn
(
"You have another dimension set as gca."
"I will force the new dimension to return."
)
axis
=
plt
.
gcf
().
add_subplot
(
1
,
1
,
1
,
**
kwargs
)
return
axis
def
axisDim
(
axis
):
"""
Returns int: axis dimension
"""
if
hasattr
(
axis
,
'get_zlim'
):
return
3
else
:
return
2
def
setLabels
(
axis
,
*
labels
):
axis
.
set_xlabel
(
labels
[
0
])
axis
.
set_ylabel
(
labels
[
1
])
if
axisDim
(
axis
)
==
3
:
axis
.
set_zlabel
(
labels
[
2
])
def
autoscale3D
(
axis
,
array
=
None
,
xLim
=
None
,
yLim
=
None
,
zLim
=
None
):
if
array
is
not
None
:
xMin
,
yMin
,
zMin
=
array
.
min
(
axis
=
0
)
xMax
,
yMax
,
zMax
=
array
.
max
(
axis
=
0
)
xLim
=
(
xMin
,
xMax
)
yLim
=
(
yMin
,
yMax
)
zLim
=
(
zMin
,
zMax
)
xLimAxis
=
axis
.
get_xlim
()
yLimAxis
=
axis
.
get_ylim
()
zLimAxis
=
axis
.
get_zlim
()
if
not
False
:
# not empty axis
xMin
=
min
(
xLimAxis
[
0
],
xLim
[
0
])
yMin
=
min
(
yLimAxis
[
0
],
yLim
[
0
])
zMin
=
min
(
zLimAxis
[
0
],
zLim
[
0
])
xMax
=
max
(
xLimAxis
[
1
],
xLim
[
1
])
yMax
=
max
(
yLimAxis
[
1
],
yLim
[
1
])
zMax
=
max
(
zLimAxis
[
1
],
zLim
[
1
])
axis
.
set_xlim
([
xMin
,
xMax
])
axis
.
set_ylim
([
yMin
,
yMax
])
axis
.
set_zlim
([
zMin
,
zMax
])
class
PlotOptions
(
object
):
"""
processing kwargs for plotting functions and providing easy
...
...
@@ -115,9 +60,9 @@ class PlotOptions(object):
if
dim
is
None
:
if
self
.
_axis
is
None
:
dim
=
2
dim
=
axis
D
im
(
self
.
_axis
)
dim
=
axis
_d
im
(
self
.
_axis
)
elif
self
.
_axis
is
not
None
:
if
not
dim
==
axis
D
im
(
self
.
_axis
):
if
not
dim
==
axis
_d
im
(
self
.
_axis
):
raise
ValueError
(
"Axis and dim argument are in conflict."
)
if
dim
not
in
[
2
,
3
]:
raise
NotImplementedError
(
"Dimensions other than 2 or 3 are not supported."
)
...
...
@@ -198,7 +143,7 @@ class PlotOptions(object):
cmap
,
vmin
,
vmax
=
self
.
getNormArgs
(
cmapDefault
=
'NotSpecified'
,
vminDefault
=
None
,
vmaxDefault
=
None
)
colors
=
getColorsInve
rs
e
(
colors
,
cmap
,
vmin
,
vmax
)
colors
=
to_scala
rs
(
colors
,
cmap
,
vmin
,
vmax
)
self
.
plotKwargs
[
'vmin'
]
=
vmin
self
.
plotKwargs
[
'vmax'
]
=
vmax
self
.
plotKwargs
[
'cmap'
]
=
cmap
...
...
@@ -214,7 +159,7 @@ class PlotOptions(object):
self
.
setVminVmaxAuto
(
vmin
,
vmax
,
colors
)
# update vmin and vmax
cmap
,
vmin
,
vmax
=
self
.
getNormArgs
()
colors
=
getC
olors
(
colors
,
colors
=
to_c
olors
(
colors
,
vmin
=
vmin
,
vmax
=
vmax
,
cmap
=
cmap
)
...
...
tfields/plotting/mpl.py
View file @
5750f527
"""
Matplotlib specific plotting
"""
import
tfields
import
numpy
as
np
import
warnings
import
os
import
matplotlib
as
mpl
import
matplotlib.pyplot
as
plt
from
matplotlib.patches
import
Circle
import
mpl_toolkits.mplot3d
as
plt3D
from
mpl_toolkits.axes_grid1
import
make_axes_locatable
import
matplotlib.dates
as
dates
from
itertools
import
cycle
import
logging
def
gca
(
dim
=
None
,
**
kwargs
):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if
dim
==
3
:
axis
=
plt
.
gca
(
projection
=
'3d'
,
**
kwargs
)
else
:
axis
=
plt
.
gca
(
**
kwargs
)
if
dim
!=
axis_dim
(
axis
):
if
dim
is
not
None
:
warnings
.
warn
(
"You have another dimension set as gca."
"I will force the new dimension to return."
)
axis
=
plt
.
gcf
().
add_subplot
(
1
,
1
,
1
,
**
kwargs
)
return
axis
def
upgrade_style
(
style
,
source
,
dest
=
"~/.config/matplotlib/"
):
"""
Copy a style file at <origionalFilePath> to the <dest> which is the foreseen
local matplotlib rc dir by default
The style will be name <style>.mplstyle
Args:
style (str): name of style
source (str): full path to mplstyle file to use
dest (str): local directory to copy the file to. Matpotlib has to
search this directory for mplstyle files!
"""
styleExtension
=
'mplstyle'
path
=
tfields
.
lib
.
in_out
.
resolve
(
os
.
path
.
join
(
dest
,
style
+
'.'
+
styleExtension
))
source
=
tfields
.
lib
.
in_out
.
resolve
(
source
)
tfields
.
lib
.
in_out
.
cp
(
source
,
path
)
def
set_style
(
style
=
'tfields'
,
dest
=
"~/.config/matplotlib/"
):
"""
Set the matplotlib style of name
Important:
Either you
Args:
style (str)
dest (str): local directory to use file from. if None, use default maplotlib styles
"""
if
dest
is
None
:
path
=
style
else
:
styleExtension
=
'mplstyle'
path
=
tfields
.
lib
.
in_out
.
resolve
(
os
.
path
.
join
(
dest
,
style
+
'.'
+
styleExtension
))
try
:
plt
.
style
.
use
(
path
)
except
IOError
:
log
=
logging
.
getLogger
()
if
style
==
'tfields'
:
log
.
warning
(
"I will copy the default style to {dest}."
.
format
(
**
locals
()))
source
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
style
+
'.'
+
styleExtension
)
upgrade_style
(
style
,
source
,
dest
)
set_style
(
style
)
else
:
log
.
error
(
"Could not set style {path}. Probably you would want to"
"call tfields.plotting.upgrade_style(<style>, "
"<path to mplstyle file that should be copied>)"
"once"
.
format
(
**
locals
()))
def
save
(
path
,
*
fmts
,
**
kwargs
):
"""
Args:
path (str): path without extension to save to
*fmts (str): format of the figure to save. If multiple are given, create
that many files
**kwargs:
axis
fig
"""
log
=
logging
.
getLogger
()
# catch figure from axis or fig
axis
=
kwargs
.
get
(
'axis'
,
None
)
if
axis
is
None
:
figDefault
=
plt
.
gcf
()
axis
=
gca
()
else
:
figDefault
=
axis
.
figure
fig
=
kwargs
.
get
(
'fig'
,
figDefault
)
# set current figure
plt
.
figure
(
fig
.
number
)
# crop the plot down based on the extents of the artists in the plot
kwargs
[
'bbox_inches'
]
=
kwargs
.
pop
(
'bbox_inches'
,
'tight'
)
if
kwargs
[
'bbox_inches'
]
==
'tight'
:
extraArtists
=
None
for
ax
in
fig
.
get_axes
():
firstLabel
=
ax
.
get_legend_handles_labels
()[
0
]
or
None
if
firstLabel
:
if
not
extraArtists
:
extraArtists
=
[]
extraArtists
.
append
(
firstLabel
)
kwargs
[
'bbox_extra_artists'
]
=
kwargs
.
pop
(
'bbox_extra_artists'
,
extraArtists
)
if
len
(
fmts
)
!=
0
:
for
fmt
in
fmts
:
if
path
.
endswith
(
'.'
):
newFilePath
=
path
+
fmt
elif
'{fmt}'
in
path
:
newFilePath
=
path
.
format
(
**
locals
())
else
:
newFilePath
=
path
+
'.'
+
fmt
save
(
newFilePath
,
**
kwargs
)
else
:
path
=
tfields
.
lib
.
in_out
.
resolve
(
path
)
log
.
info
(
"Saving figure as {0}"
.
format
(
path
))
plt
.
savefig
(
path
,
**
kwargs
)
def
plot
A
rray
(
array
,
**
kwargs
):
def
plot
_a
rray
(
array
,
**
kwargs
):
"""
Points3D plotting method.
...
...
@@ -30,7 +155,7 @@ def plotArray(array, **kwargs):
labelList
=
po
.
pop
(
'labelList'
,
[
'x (m)'
,
'y (m)'
,
'z (m)'
])
xAxis
,
yAxis
,
zAxis
=
po
.
getXYZAxis
()
tfields
.
plotting
.
set
L
abels
(
po
.
axis
,
*
po
.
getSortedLabels
(
labelList
))
tfields
.
plotting
.
set
_l
abels
(
po
.
axis
,
*
po
.
getSortedLabels
(
labelList
))
if
zAxis
is
None
:
args
=
[
array
[:,
xAxis
],
array
[:,
yAxis
]]
...
...
@@ -43,7 +168,7 @@ def plotArray(array, **kwargs):
return
artist
def
plot
M
esh
(
vertices
,
faces
,
**
kwargs
):
def
plot
_m
esh
(
vertices
,
faces
,
**
kwargs
):
"""
Args:
axis (matplotlib axis)
...
...
@@ -85,7 +210,7 @@ def plotMesh(vertices, faces, **kwargs):
directionVector
=
np
.
array
([
1.
,
1.
,
1.
])
directionVector
[
xAxis
]
=
0.
directionVector
[
yAxis
]
=
0.
normVectors
=
mesh
.
triangles
.
norms
()
normVectors
=
mesh
.
triangles
()
.
norms
()
dotProduct
=
np
.
dot
(
normVectors
,
directionVector
)
nFacesInitial
=
len
(
faces
)
faces
=
faces
[
dotProduct
>
0
]
...
...
@@ -108,7 +233,7 @@ def plotMesh(vertices, faces, **kwargs):
d
=
po
.
plotKwargs
d
[
'xAxis'
]
=
xAxis
d
[
'yAxis'
]
=
yAxis
artist
=
plot
A
rray
(
vertices
,
**
d
)
artist
=
plot
_a
rray
(
vertices
,
**
d
)
elif
po
.
dim
==
3
:
label
=
po
.
pop
(
'label'
,
None
)
color
=
po
.
retrieveChain
(
'color'
,
'c'
,
'facecolors'
,
...
...
@@ -141,7 +266,7 @@ def plotMesh(vertices, faces, **kwargs):
artist
.
set_alpha
(
alpha
)
# for some reason auto-scale does not work
tfields
.
plotting
.
autoscale
3D
(
po
.
axis
,
array
=
vertices
)
tfields
.
plotting
.
autoscale
_3d
(
po
.
axis
,
array
=
vertices
)
# legend lables do not work at all as an argument
if
label
:
...
...
@@ -152,12 +277,15 @@ def plotMesh(vertices, faces, **kwargs):
artist
.
_facecolors2d
=
None
labelList
=
[
'x (m)'
,
'y (m)'
,
'z (m)'
]
tfields
.
plotting
.
setLabels
(
po
.
axis
,
*
po
.
getSortedLabels
(
labelList
))
tfields
.
plotting
.
set_labels
(
po
.
axis
,
*
po
.
getSortedLabels
(
labelList
))
else
:
raise
NotImplementedError
(
"Dimension != 2|3"
)
return
artist
def
plot
VectorF
ield
(
points
,
vectors
,
**
kwargs
):
def
plot
_tensor_f
ield
(
points
,
vectors
,
**
kwargs
):
"""
Args:
points (array_like): base vectors
...
...
@@ -173,14 +301,16 @@ def plotVectorField(points, vectors, **kwargs):
artists
.
append
(
po
.
axis
.
quiver
(
point
[
xAxis
],
point
[
yAxis
],
point
[
zAxis
],
vector
[
xAxis
],
vector
[
yAxis
],
vector
[
zAxis
],
**
po
.
plotKwargs
))
el
se
:
el
if
po
.
dim
==
2
:
artists
.
append
(
po
.
axis
.
quiver
(
point
[
xAxis
],
point
[
yAxis
],
vector
[
xAxis
],
vector
[
yAxis
],
**
po
.
plotKwargs
))
else
:
raise
NotImplementedError
(
"Dimension != 2|3"
)
return
artists
def
plot
P
lane
(
point
,
normal
,
**
kwargs
):
def
plot
_p
lane
(
point
,
normal
,
**
kwargs
):
def
plot_vector
(
fig
,
orig
,
v
,
color
=
'blue'
):
axis
=
fig
.
gca
(
projection
=
'3d'
)
...
...
@@ -241,7 +371,7 @@ def plotPlane(point, normal, **kwargs):
pathpatch_translate
(
patch
,
(
point
[
0
],
point
[
1
],
point
[
2
]))
def
plot
S
phere
(
point
,
radius
,
**
kwargs
):
def
plot
_s
phere
(
point
,
radius
,
**
kwargs
):
po
=
tfields
.
plotting
.
PlotOptions
(
kwargs
)
# Make data
u
=
np
.
linspace
(
0
,
2
*
np
.
pi
,
100
)
...
...
@@ -254,10 +384,34 @@ def plotSphere(point, radius, **kwargs):
return
po
.
axis
.
plot_surface
(
x
,
y
,
z
,
**
po
.
plotKwargs
)
def
plot_function
(
fun
,
**
kwargs
):
"""
Args:
axis (matplotlib.Axis) object
Returns:
Artist or list of Artists (imitating the axis.scatter/plot behaviour).
Better Artist not list of Artists
"""
import
numpy
as
np
labelList
=
[
'x'
,
'f(x)'
]
po
=
tfields
.
plotting
.
PlotOptions
(
kwargs
)
tfields
.
plotting
.
set_labels
(
po
.
axis
,
*
labelList
)
xMin
,
xMax
=
po
.
pop
(
'xMin'
,
0
),
po
.
pop
(
'xMax'
,
1
)
n
=
po
.
pop
(
'n'
,
100
)
vals
=
np
.
linspace
(
xMin
,
xMax
,
n
)
args
=
(
vals
,
map
(
fun
,
vals
))
artist
=
po
.
axis
.
plot
(
*
args
,
**
po
.
plotKwargs
)
return
artist
"""
Color section
"""
def
getColors
(
scalars
,
cmap
=
None
,
vmin
=
None
,
vmax
=
None
):
def
to_colors
(
scalars
,
cmap
=
None
,
vmin
=
None
,
vmax
=
None
):
"""
retrieve the colors for a list of scalars
"""
...
...
@@ -272,8 +426,9 @@ def getColors(scalars, cmap=None, vmin=None, vmax=None):
return
colorMap
(
map
(
norm
,
scalars
))
def
getColorsInve
rs
e
(
colors
,
cmap
,
vmin
,
vmax
):
def
to_scala
rs
(
colors
,
cmap
,
vmin
,
vmax
):
"""
Inverse 'to_colors'
Reconstruct the numeric values (0 - 1) of given
Args:
colors (list or rgba tuple)
...
...
@@ -283,7 +438,7 @@ def getColorsInverse(colors, cmap, vmin, vmax):
"""
# colors = np.array(colors)/255.
r
=
np
.
linspace
(
vmin
,
vmax
,
256
)
norm
=
m
atplotlib
.
colors
.
Normalize
(
vmin
,
vmax
)
norm
=
m
pl
.
colors
.
Normalize
(
vmin
,
vmax
)
mapvals
=
cmap
(
norm
(
r
))[:,
:
4
]
# there are 4 channels: r,g,b,a
scalars
=
[]
for
color
in
colors
:
...
...
@@ -292,8 +447,159 @@ def getColorsInverse(colors, cmap, vmin, vmax):
return
scalars
def
colormap
(
seq
):
"""
Args:
seq (iterable): a sequence of floats and RGB-tuples. The floats should be increasing
and in the interval (0,1).
Returns:
LinearSegmentedColormap
"""
seq
=
[(
None
,)
*
3
,
0.0
]
+
list
(
seq
)
+
[
1.0
,
(
None
,)
*
3
]
cdict
=
{
'red'
:
[],
'green'
:
[],
'blue'
:
[]}
for
i
,
item
in
enumerate
(
seq
):
if
isinstance
(
item
,
float
):
r1
,
g1
,
b1
=
seq
[
i
-
1
]
r2
,
g2
,
b2
=
seq
[
i
+
1
]
cdict
[
'red'
].
append
([
item
,
r1
,
r2
])
cdict
[
'green'
].
append
([
item
,
g1
,
g2
])
cdict
[
'blue'
].
append
([
item
,
b1
,
b2
])
return
mpl
.
colors
.
LinearSegmentedColormap
(
'CustomMap'
,
cdict
)
def
color_cycle
(
colormap
=
None
,
n
=
None
):
"""
Args:
colormap (matplotlib colormap): e.g. plotTools.plt.cm.coolwarm
n (int): needed for colormap argument
"""
if
colormap
:
color_rgb
=
to_colors
(
np
.
linspace
(
0
,
1
,
n
),
cmap
=
colormap
,
vmin
=
0
,
vmax
=
1
)
colors
=
map
(
lambda
rgb
:
'#%02x%02x%02x'
%
(
rgb
[
0
]
*
255
,
rgb
[
1
]
*
255
,
rgb
[
2
]
*
255
),
tuple
(
color_rgb
[:,
0
:
-
1
]))
else
:
colors
=
list
([
color
[
'color'
]
for
color
in
mpl
.
rcParams
[
'axes.prop_cycle'
]])
return
cycle
(
colors
)
"""
Display section
"""
def
axis_dim
(
axis
):
"""
Returns int: axis dimension
"""
if
hasattr
(
axis
,
'get_zlim'
):
return
3
else
:
return
2
def
set_aspect_equal
(
axis
):
"""Fix equal aspect bug for 3D plots."""
if
axis_dim
(
axis
)
==
2
:
axis
.
set_aspect
(
'equal'
)
return
xlim
=
axis
.
get_xlim3d
()
ylim
=
axis
.
get_ylim3d
()
zlim
=
axis
.
get_zlim3d
()
from
numpy
import
mean
xmean
=
mean
(
xlim
)
ymean
=
mean
(
ylim
)
zmean
=
mean
(
zlim
)
plot_radius
=
max
([
abs
(
lim
-
mean_
)
for
lims
,
mean_
in
((
xlim
,
xmean
),
(
ylim
,
ymean
),
(
zlim
,
zmean
))
for
lim
in
lims
])
axis
.
set_xlim3d
([
xmean
-
plot_radius
,
xmean
+
plot_radius
])
axis
.
set_ylim3d
([
ymean
-
plot_radius
,
ymean
+
plot_radius
])
axis
.
set_zlim3d
([
zmean
-
plot_radius
,
zmean
+
plot_radius
])
def
set_axis_off
(
axis
):
if
axis_dim
(
axis
)
==
2
:
axis
.
set_axis_off
()
else
:
axis
.
_axis3don
=
False
def
autoscale_3d
(
axis
,
array
=
None
,
xLim
=
None
,
yLim
=
None
,
zLim
=
None
):
if
array
is
not
None
:
xMin
,
yMin
,
zMin
=
array
.
min
(
axis
=
0
)
xMax
,
yMax
,
zMax
=
array
.
max
(
axis
=
0
)
xLim
=
(
xMin
,
xMax
)
yLim
=
(
yMin
,
yMax
)
zLim
=
(
zMin
,
zMax
)
xLimAxis
=
axis
.
get_xlim
()
yLimAxis
=
axis
.
get_ylim
()
zLimAxis
=
axis
.
get_zlim
()
if
not
False
:
# not empty axis
xMin
=
min
(
xLimAxis
[
0
],
xLim
[
0
])
yMin
=
min
(
yLimAxis
[
0
],
yLim
[
0
])
zMin
=
min
(
zLimAxis
[
0
],
zLim
[
0
])
xMax
=
max
(
xLimAxis
[
1
],
xLim
[
1
])
yMax
=
max
(
yLimAxis
[
1
],
yLim
[
1
])
zMax
=
max
(
zLimAxis
[
1
],
zLim
[
1
])
axis
.
set_xlim
([
xMin
,
xMax
])
axis
.
set_ylim
([
yMin
,
yMax
])
axis
.
set_zlim
([
zMin
,
zMax
])
def
setLegend
(
axis
,
artists
):
handles
=
[]
for
artist
in
artists
:
if
isinstance
(
artist
,
list
):
handles
.
append
(
artist
[
0
])
else
:
handles
.
append
(
artist
)
axis
.
legend
(
handles
=
handles
)
def
set_color_bar
(
axis
,
artist
,
label
=
None
,
divide
=
True
,
**
kwargs
):
# colorbar
if
divide
:
divider
=
make_axes_locatable
(
axis
)
axis
=
divider
.
append_axes
(
"right"
,
size
=
"2%"
,
pad
=
0.05
)
cbar
=
plt
.
colorbar
(
artist
,
cax
=
axis
,
**
kwargs
)
# label
if
label
is
None
:
artLabel
=
artist
.
get_label
()