Commit 4cb54f3e authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

tests work, plotting mesh and tensors also

parent fbb46f42
......@@ -4,7 +4,6 @@ from . import core
from . import bases
from . import lib
from .lib import *
from . import plotting
# __all__ = ['core', 'points3D']
from .core import Tensors, TensorFields, TensorMaps, Container
......
......@@ -20,6 +20,7 @@ from collections import Counter
import numpy as np
import sympy
import scipy as sp
import rna
import tfields.bases
np.seterr(all='warn', over='raise')
......@@ -241,7 +242,7 @@ class AbstractNdarray(np.ndarray):
raise NotImplementedError("Can not find save method for extension: "
"{extension}.".format(**locals()))
path = tfields.lib.in_out.resolve(path)
path = rna.path.resolve(path)
return save_method(path, **kwargs)
@classmethod
......@@ -260,7 +261,7 @@ class AbstractNdarray(np.ndarray):
if isinstance(path, (string_types, pathlib.Path)):
extension = pathlib.Path(path).suffix.lstrip('.')
path = str(path)
path = tfields.lib.in_out.resolve(path)
path = rna.path.resolve(path)
else:
extension = kwargs.pop('extension', 'npz')
......@@ -1335,9 +1336,9 @@ class Tensors(AbstractNdarray):
def plot(self, **kwargs):
"""
Forwarding to tfields.lib.plotting.plot_array
Forwarding to rna.plotting.plot_array
"""
artist = tfields.plotting.plot_array(self, **kwargs)
artist = rna.plotting.plot_array(self, **kwargs)
return artist
......@@ -1598,7 +1599,7 @@ class TensorFields(Tensors):
log.debug("Careful: Plotting tensors with field of"
"different dimension. No coord_sys check performed.")
if field.dim <= 3:
artist = tfields.plotting.plot_tensor_field(self, field,
artist = rna.plotting.plot_tensor_field(self, field,
**kwargs)
else:
raise NotImplementedError("Field of dimension {field.dim}"
......
......@@ -116,5 +116,3 @@ else:
from . import symbolics
from . import sets
from . import util
from . import in_out
from . import log
from contextlib import contextmanager
import logging
import time
def progressbar(iterable, log=None):
"""
Examples:
>>> import logging
>>> import tfields
>>> import sys
>>> sys.modules['tqdm'] = None
>>> log = logging.getLogger(__name__)
>>> a = range(3)
>>> for value in tfields.lib.log.progressbar(a, log=log):
... _ = value * 3
"""
if log is None:
log = logging.getLogger()
try:
from tqdm import tqdm as progressor
tqdm_exists = True
except ImportError as err:
def progressor(iterable):
"""
dummy function. Doe nothing
"""
return iterable
tqdm_exists = False
try:
nTotal = len(iterable)
except:
nTotal = None
for i in progressor(iterable):
if not tqdm_exists:
if nTotal is None:
log.info("Progress: item {i}".format(**locals()))
else:
log.info("Progress: {i} / {nTotal}".format(**locals()))
yield i
@contextmanager
def timeit(msg="No Description", log=None, precision=1):
"""
Context manager for autmated timeing
Args:
msg (str): message to customize the log message
log (logger)
precision (int): show until 10^-<precision> digits
"""
if log is None:
log = logging.getLogger()
startTime = time.time()
log.log(logging.INFO, "-> " * 30)
message = "Starting Process: {0} ->".format(msg)
log.log(logging.INFO, message)
yield
log.log(logging.INFO, "\t\t\t\t\t\t<- Process Duration:"
"{value:1.{precision}f} s".format(value=time.time() - startTime,
precision=precision))
log.log(logging.INFO, "<- " * 30)
if __name__ == '__main__':
import doctest
doctest.testmod()
......@@ -8,6 +8,7 @@ part of tfields library
"""
import numpy as np
import sympy
import rna
import tfields
# obj imports
......@@ -989,10 +990,10 @@ class Mesh3D(tfields.TensorMaps):
kwargs['color'] = self.maps[0].fields[map_index]
dim_defined = False
if 'axis' in kwargs:
if 'axes' in kwargs:
dim_defined = True
if 'zAxis' in kwargs:
if kwargs['zAxis'] is not None:
if 'z_index' in kwargs:
if kwargs['z_index'] is not None:
kwargs['dim'] = 3
else:
kwargs['dim'] = 2
......@@ -1003,7 +1004,7 @@ class Mesh3D(tfields.TensorMaps):
if not dim_defined:
kwargs['dim'] = 2
return tfields.plotting.plot_mesh(self, self.faces, **kwargs)
return rna.plotting.plot_mesh(self, self.faces, **kwargs)
if __name__ == '__main__': # pragma: no cover
......
......@@ -9,6 +9,7 @@ part of tfields library
import tfields
import sympy
import numpy as np
import rna
class Planes3D(tfields.TensorFields):
......@@ -41,9 +42,11 @@ class Planes3D(tfields.TensorFields):
centers = np.array(self)
norms = np.array(self.fields[0])
for i in range(len(self)):
artists.append(tfields.plotting.plot_plane(centers[i],
norms[i],
**kwargs))
artists.append(
rna.plotting.plot_plane(
centers[i],
norms[i],
**kwargs))
# symbolic = self.symbolic()
# planeMeshes = [tfields.Mesh3D([pl.arbitrary_point(t=(i + 1) * 1. / 2 * np.pi)
# for i in range(4)],
......
"""
Core plotting tools for tfields library. Especially PlotOptions class
is basis for many plotting expansions
TODO:
* add other library backends. Do not restrict to mpl
"""
import warnings
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from .mpl import *
from six import string_types
def set_default(dictionary, attr, value):
"""
Set defaults to a dictionary
"""
if attr not in dictionary:
dictionary[attr] = value
class PlotOptions(object):
"""
processing kwargs for plotting functions and providing easy
access to axis, dimension and plotting method as well as indices
for array choice (x..., y..., zAxis)
"""
def __init__(self, kwargs):
kwargs = dict(kwargs)
self.axis = kwargs.pop('axis', None)
self.dim = kwargs.pop('dim', None)
self.method = kwargs.pop('methodName', None)
self.setXYZAxis(kwargs)
self.plot_kwargs = kwargs
@property
def method(self):
"""
Method for plotting. Will be callable together with plot_kwargs
"""
return self._method
@method.setter
def method(self, methodName):
if not isinstance(methodName, str):
self._method = methodName
else:
self._method = getattr(self.axis, methodName)
@property
def dim(self):
"""
axis dimension
"""
return self._dim
@dim.setter
def dim(self, dim):
if dim is None:
if self._axis is None:
dim = 2
dim = axis_dim(self._axis)
elif self._axis is not None:
if not dim == axis_dim(self._axis):
raise ValueError("Axis and dim argument are in conflict.")
if dim not in [2, 3]:
raise NotImplementedError("Dimensions other than 2 or 3 are not supported.")
self._dim = dim
@property
def axis(self):
"""
The plt.Axis object that belongs to this instance
"""
if self._axis is None:
return gca(self._dim)
else:
return self._axis
@axis.setter
def axis(self, axis):
self._axis = axis
def setXYZAxis(self, kwargs):
self._xAxis = kwargs.pop('xAxis', 0)
self._yAxis = kwargs.pop('yAxis', 1)
zAxis = kwargs.pop('zAxis', None)
if zAxis is None and self.dim == 3:
indicesUsed = [0, 1, 2]
indicesUsed.remove(self._xAxis)
indicesUsed.remove(self._yAxis)
zAxis = indicesUsed[0]
self._zAxis = zAxis
def getXYZAxis(self):
return self._xAxis, self._yAxis, self._zAxis
def setVminVmaxAuto(self, vmin, vmax, scalars):
"""
Automatically set vmin and vmax as min/max of scalars
but only if vmin or vmax is None
"""
if scalars is None:
return
if len(scalars) < 2:
warnings.warn("Need at least two scalars to autoset vmin and/or vmax!")
return
if vmin is None:
vmin = min(scalars)
self.plot_kwargs['vmin'] = vmin
if vmax is None:
vmax = max(scalars)
self.plot_kwargs['vmax'] = vmax
def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None):
if cmapDefault is None:
cmapDefault = plt.rcParams['image.cmap']
cmap = self.get('cmap', cmapDefault)
vmin = self.get('vmin', vminDefault)
vmax = self.get('vmax', vmaxDefault)
return cmap, vmin, vmax
def format_colors(self, colors, fmt='rgba', length=None):
"""
format colors according to fmt argument
Args:
colors (list/one value of rgba tuples/int/float/str): This argument will
be interpreted as color
fmt (str): rgba | hex | norm
length (int/None): if not None: correct colors lenght
Returns:
colors in fmt
"""
hasIter = True
if not hasattr(colors, '__iter__') or isinstance(colors, string_types):
# colors is just one element
hasIter = False
colors = [colors]
if fmt == 'norm':
if hasattr(colors[0], '__iter__'):
# rgba given but norm wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
colors = to_scalars(colors, cmap, vmin, vmax)
self.plot_kwargs['vmin'] = vmin
self.plot_kwargs['vmax'] = vmax
self.plot_kwargs['cmap'] = cmap
elif fmt == 'rgba':
if isinstance(colors[0], string_types):
# string color defined
colors = [mpl.colors.to_rgba(color) for color in colors]
else:
# norm given rgba wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
self.setVminVmaxAuto(vmin, vmax, colors)
# update vmin and vmax
cmap, vmin, vmax = self.getNormArgs()
colors = to_colors(colors,
vmin=vmin,
vmax=vmax,
cmap=cmap)
elif fmt == 'hex':
colors = [mpl.colors.to_hex(color) for color in colors]
else:
raise NotImplementedError("Color fmt '{fmt}' not implemented."
.format(**locals()))
if length is not None:
# just one colors value given
if len(colors) != length:
if not len(colors) == 1:
raise ValueError("Can not correct color length")
colors = list(colors)
colors *= length
elif not hasIter:
colors = colors[0]
colors = np.array(colors)
return colors
def delNormArgs(self):
self.plot_kwargs.pop('vmin', None)
self.plot_kwargs.pop('vmax', None)
self.plot_kwargs.pop('cmap', None)
def getSortedLabels(self, labels):
"""
Returns the labels corresponding to the axes
"""
return [labels[i] for i in self.getXYZAxis() if i is not None]
def get(self, attr, default=None):
return self.plot_kwargs.get(attr, default)
def pop(self, attr, default=None):
return self.plot_kwargs.pop(attr, default)
def set(self, attr, value):
self.plot_kwargs[attr] = value
def set_default(self, attr, value):
set_default(self.plot_kwargs, attr, value)
def retrieve(self, attr, default=None, keep=True):
if keep:
return self.get(attr, default)
else:
return self.pop(attr, default)
def retrieve_chain(self, *args, **kwargs):
default = kwargs.pop('default', None)
keep = kwargs.pop('keep', True)
if len(args) > 1:
return self.retrieve(args[0],
self.retrieve_chain(*args[1:],
default=default,
keep=keep),
keep=keep)
if len(args) != 1:
raise ValueError("Invalid number of args ({0})".format(len(args)))
return self.retrieve(args[0], default, keep=keep)
if __name__ == '__main__':
import doctest
doctest.testmod()
"""
Matplotlib specific plotting
"""
import tfields
import numpy as np
import warnings
import os
import matplotlib as mpl
import matplotlib.ticker
from matplotlib import style
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import mpl_toolkits.mplot3d as plt3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.dates as dates
from matplotlib.patches import Rectangle
from itertools import cycle
from functools import partial
import logging
def show():
plt.show()
def gca(dim=None, **kwargs):
"""
Forwarding to plt.gca but translating the dimension to projection
correct dimension
"""
if dim == 3:
axis = plt.gca(projection='3d', **kwargs)
else:
axis = plt.gca(**kwargs)
if dim != axis_dim(axis):
if dim is not None:
warnings.warn("You have another dimension set as gca."
"I will force the new dimension to return.")
axis = plt.gcf().add_subplot(1, 1, 1, **kwargs)
return axis
def upgrade_style(style, source, dest=None):
"""
Copy a style file at <origionalFilePath> to the <dest> which is the foreseen
local matplotlib rc dir by default
The style will be name <style>.mplstyle
Args:
style (str): name of style
source (str): full path to mplstyle file to use
dest (str): local directory to copy the file to. Matpotlib has to
search this directory for mplstyle files!
Examples:
>>> import tfields
>>> import os
>>> tfields.plotting.upgrade_style(
... 'tfields',
... os.path.join(os.path.dirname(tfields.plotting.__file__),
... 'tfields.mplstyle'))
"""
if dest is None:
dest = mpl.get_configdir()
style_extension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' +
style_extension))
source = tfields.lib.in_out.resolve(source)
tfields.lib.in_out.cp(source, path)
def set_style(style='tfields', dest=None):
"""
Set the matplotlib style of name
Important:
Either you
Args:
style (str)
dest (str): local directory to use file from. if None, use default
maplotlib destination
"""
if dest is None:
dest = mpl.get_configdir()
style_extension = 'mplstyle'
path = tfields.lib.in_out.resolve(os.path.join(dest, style + '.' +
style_extension))
if style in mpl.style.available:
plt.style.use(style)
elif os.path.exists(path):
plt.style.use(path)
else:
log = logging.getLogger()
if style == 'tfields':
log.warning("I will copy the default style to {dest}."
.format(**locals()))
source = os.path.join(os.path.dirname(__file__),
style + '.' + style_extension)
try:
upgrade_style(style, source, dest)
set_style(style)
except Exception:
log.error("Could not set style")
else:
log.error("Could not set style {path}. Probably you would want to"
"call tfields.plotting.upgrade_style(<style>, "
"<path to mplstyle file that should be copied>)"
"once".format(**locals()))
def save(path, *fmts, **kwargs):
"""
Args:
path (str): path without extension to save to
*fmts (str): format of the figure to save. If multiple are given, create
that many files
**kwargs:
axis
fig
"""
log = logging.getLogger()
# catch figure from axis or fig
axis = kwargs.get('axis', None)
if axis is None:
fig_default = plt.gcf()
axis = gca()
if fig_default is None:
raise ValueError("fig_default may not be None")
else:
fig_default = axis.figure
fig = kwargs.get('fig', fig_default)
# set current figure
plt.figure(fig.number)
# crop the plot down based on the extents of the artists in the plot
kwargs['bbox_inches'] = kwargs.pop('bbox_inches', 'tight')
if kwargs['bbox_inches'] == 'tight':
extra_artists = None
for ax in fig.get_axes():
first_label = ax.get_legend_handles_labels()[0] or None
if first_label:
if not extra_artists:
extra_artists = []
if isinstance(first_label, list):
extra_artists.extend(first_label)
else:
extra_artists.append(first_label)
kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists',
extra_artists)
if len(fmts) != 0:
for fmt in fmts:
if path.endswith('.'):
new_file_path = path + fmt
elif '{fmt}' in path:
new_file_path = path.format(**locals())
else:
new_file_path = path + '.' + fmt
save(new_file_path, **kwargs)
else:
path = tfields.lib.in_out.resolve(path)
log.info("Saving figure as {0}".format(path))
plt.savefig(path,
**kwargs)
def plot_array(array, **kwargs):
"""
Points3D plotting method.
Args:
array (numpy array)
axis (matplotlib.Axis) object
xAxis (int): coordinate index that should be on xAxis
yAxis (int): coordinate index that should be on yAxis
zAxis (int or None): coordinate index that should be on zAxis.
If it evaluates to None, 2D plot will be done.
methodName (str): method name to use for filling the axis
Returns:
Artist or list of Artists (imitating the axis.scatter/plot behaviour).
Better Artist not list of Artists
"""
array = np.array(array)
tfields.plotting.set_default(kwargs, 'methodName', 'scatter')
po = tfields.plotting.PlotOptions(kwargs)
labels = po.pop('labels', ['x (m)', 'y (m)', 'z (m)'])
xAxis, yAxis, zAxis = po.getXYZAxis()
tfields.plotting.set_labels(po.axis, *po.getSortedLabels(labels))
if zAxis is None: