Commit 4cb54f3e authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

tests work, plotting mesh and tensors also

parent fbb46f42
...@@ -4,7 +4,6 @@ from . import core ...@@ -4,7 +4,6 @@ from . import core
from . import bases from . import bases
from . import lib from . import lib
from .lib import * from .lib import *
from . import plotting
# __all__ = ['core', 'points3D'] # __all__ = ['core', 'points3D']
from .core import Tensors, TensorFields, TensorMaps, Container from .core import Tensors, TensorFields, TensorMaps, Container
......
...@@ -20,6 +20,7 @@ from collections import Counter ...@@ -20,6 +20,7 @@ from collections import Counter
import numpy as np import numpy as np
import sympy import sympy
import scipy as sp import scipy as sp
import rna
import tfields.bases import tfields.bases
np.seterr(all='warn', over='raise') np.seterr(all='warn', over='raise')
...@@ -241,7 +242,7 @@ class AbstractNdarray(np.ndarray): ...@@ -241,7 +242,7 @@ class AbstractNdarray(np.ndarray):
raise NotImplementedError("Can not find save method for extension: " raise NotImplementedError("Can not find save method for extension: "
"{extension}.".format(**locals())) "{extension}.".format(**locals()))
path = tfields.lib.in_out.resolve(path) path = rna.path.resolve(path)
return save_method(path, **kwargs) return save_method(path, **kwargs)
@classmethod @classmethod
...@@ -260,7 +261,7 @@ class AbstractNdarray(np.ndarray): ...@@ -260,7 +261,7 @@ class AbstractNdarray(np.ndarray):
if isinstance(path, (string_types, pathlib.Path)): if isinstance(path, (string_types, pathlib.Path)):
extension = pathlib.Path(path).suffix.lstrip('.') extension = pathlib.Path(path).suffix.lstrip('.')
path = str(path) path = str(path)
path = tfields.lib.in_out.resolve(path) path = rna.path.resolve(path)
else: else:
extension = kwargs.pop('extension', 'npz') extension = kwargs.pop('extension', 'npz')
...@@ -1335,9 +1336,9 @@ class Tensors(AbstractNdarray): ...@@ -1335,9 +1336,9 @@ class Tensors(AbstractNdarray):
def plot(self, **kwargs): def plot(self, **kwargs):
""" """
Forwarding to tfields.lib.plotting.plot_array Forwarding to rna.plotting.plot_array
""" """
artist = tfields.plotting.plot_array(self, **kwargs) artist = rna.plotting.plot_array(self, **kwargs)
return artist return artist
...@@ -1598,7 +1599,7 @@ class TensorFields(Tensors): ...@@ -1598,7 +1599,7 @@ class TensorFields(Tensors):
log.debug("Careful: Plotting tensors with field of" log.debug("Careful: Plotting tensors with field of"
"different dimension. No coord_sys check performed.") "different dimension. No coord_sys check performed.")
if field.dim <= 3: if field.dim <= 3:
artist = tfields.plotting.plot_tensor_field(self, field, artist = rna.plotting.plot_tensor_field(self, field,
**kwargs) **kwargs)
else: else:
raise NotImplementedError("Field of dimension {field.dim}" raise NotImplementedError("Field of dimension {field.dim}"
......
...@@ -116,5 +116,3 @@ else: ...@@ -116,5 +116,3 @@ else:
from . import symbolics from . import symbolics
from . import sets from . import sets
from . import util from . import util
from . import in_out
from . import log
from contextlib import contextmanager
import logging
import time
def progressbar(iterable, log=None):
"""
Examples:
>>> import logging
>>> import tfields
>>> import sys
>>> sys.modules['tqdm'] = None
>>> log = logging.getLogger(__name__)
>>> a = range(3)
>>> for value in tfields.lib.log.progressbar(a, log=log):
... _ = value * 3
"""
if log is None:
log = logging.getLogger()
try:
from tqdm import tqdm as progressor
tqdm_exists = True
except ImportError as err:
def progressor(iterable):
"""
dummy function. Doe nothing
"""
return iterable
tqdm_exists = False
try:
nTotal = len(iterable)
except:
nTotal = None
for i in progressor(iterable):
if not tqdm_exists:
if nTotal is None:
log.info("Progress: item {i}".format(**locals()))
else:
log.info("Progress: {i} / {nTotal}".format(**locals()))
yield i
@contextmanager
def timeit(msg="No Description", log=None, precision=1):
"""
Context manager for autmated timeing
Args:
msg (str): message to customize the log message
log (logger)
precision (int): show until 10^-<precision> digits
"""
if log is None:
log = logging.getLogger()
startTime = time.time()
log.log(logging.INFO, "-> " * 30)
message = "Starting Process: {0} ->".format(msg)
log.log(logging.INFO, message)
yield
log.log(logging.INFO, "\t\t\t\t\t\t<- Process Duration:"
"{value:1.{precision}f} s".format(value=time.time() - startTime,
precision=precision))
log.log(logging.INFO, "<- " * 30)
if __name__ == '__main__':
import doctest
doctest.testmod()
...@@ -8,6 +8,7 @@ part of tfields library ...@@ -8,6 +8,7 @@ part of tfields library
""" """
import numpy as np import numpy as np
import sympy import sympy
import rna
import tfields import tfields
# obj imports # obj imports
...@@ -989,10 +990,10 @@ class Mesh3D(tfields.TensorMaps): ...@@ -989,10 +990,10 @@ class Mesh3D(tfields.TensorMaps):
kwargs['color'] = self.maps[0].fields[map_index] kwargs['color'] = self.maps[0].fields[map_index]
dim_defined = False dim_defined = False
if 'axis' in kwargs: if 'axes' in kwargs:
dim_defined = True dim_defined = True
if 'zAxis' in kwargs: if 'z_index' in kwargs:
if kwargs['zAxis'] is not None: if kwargs['z_index'] is not None:
kwargs['dim'] = 3 kwargs['dim'] = 3
else: else:
kwargs['dim'] = 2 kwargs['dim'] = 2
...@@ -1003,7 +1004,7 @@ class Mesh3D(tfields.TensorMaps): ...@@ -1003,7 +1004,7 @@ class Mesh3D(tfields.TensorMaps):
if not dim_defined: if not dim_defined:
kwargs['dim'] = 2 kwargs['dim'] = 2
return tfields.plotting.plot_mesh(self, self.faces, **kwargs) return rna.plotting.plot_mesh(self, self.faces, **kwargs)
if __name__ == '__main__': # pragma: no cover if __name__ == '__main__': # pragma: no cover
......
...@@ -9,6 +9,7 @@ part of tfields library ...@@ -9,6 +9,7 @@ part of tfields library
import tfields import tfields
import sympy import sympy
import numpy as np import numpy as np
import rna
class Planes3D(tfields.TensorFields): class Planes3D(tfields.TensorFields):
...@@ -41,9 +42,11 @@ class Planes3D(tfields.TensorFields): ...@@ -41,9 +42,11 @@ class Planes3D(tfields.TensorFields):
centers = np.array(self) centers = np.array(self)
norms = np.array(self.fields[0]) norms = np.array(self.fields[0])
for i in range(len(self)): for i in range(len(self)):
artists.append(tfields.plotting.plot_plane(centers[i], artists.append(
norms[i], rna.plotting.plot_plane(
**kwargs)) centers[i],
norms[i],
**kwargs))
# symbolic = self.symbolic() # symbolic = self.symbolic()
# planeMeshes = [tfields.Mesh3D([pl.arbitrary_point(t=(i + 1) * 1. / 2 * np.pi) # planeMeshes = [tfields.Mesh3D([pl.arbitrary_point(t=(i + 1) * 1. / 2 * np.pi)
# for i in range(4)], # for i in range(4)],
......
"""
Core plotting tools for tfields library. Especially PlotOptions class
is basis for many plotting expansions
TODO:
* add other library backends. Do not restrict to mpl
"""
import warnings
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from .mpl import *
from six import string_types
def set_default(dictionary, attr, value):
"""
Set defaults to a dictionary
"""
if attr not in dictionary:
dictionary[attr] = value
class PlotOptions(object):
"""
processing kwargs for plotting functions and providing easy
access to axis, dimension and plotting method as well as indices
for array choice (x..., y..., zAxis)
"""
def __init__(self, kwargs):
kwargs = dict(kwargs)
self.axis = kwargs.pop('axis', None)
self.dim = kwargs.pop('dim', None)
self.method = kwargs.pop('methodName', None)
self.setXYZAxis(kwargs)
self.plot_kwargs = kwargs
@property
def method(self):
"""
Method for plotting. Will be callable together with plot_kwargs
"""
return self._method
@method.setter
def method(self, methodName):
if not isinstance(methodName, str):
self._method = methodName
else:
self._method = getattr(self.axis, methodName)
@property
def dim(self):
"""
axis dimension
"""
return self._dim
@dim.setter
def dim(self, dim):
if dim is None:
if self._axis is None:
dim = 2
dim = axis_dim(self._axis)
elif self._axis is not None:
if not dim == axis_dim(self._axis):
raise ValueError("Axis and dim argument are in conflict.")
if dim not in [2, 3]:
raise NotImplementedError("Dimensions other than 2 or 3 are not supported.")
self._dim = dim
@property
def axis(self):
"""
The plt.Axis object that belongs to this instance
"""
if self._axis is None:
return gca(self._dim)
else:
return self._axis
@axis.setter
def axis(self, axis):
self._axis = axis
def setXYZAxis(self, kwargs):
self._xAxis = kwargs.pop('xAxis', 0)
self._yAxis = kwargs.pop('yAxis', 1)
zAxis = kwargs.pop('zAxis', None)
if zAxis is None and self.dim == 3:
indicesUsed = [0, 1, 2]
indicesUsed.remove(self._xAxis)
indicesUsed.remove(self._yAxis)
zAxis = indicesUsed[0]
self._zAxis = zAxis
def getXYZAxis(self):
return self._xAxis, self._yAxis, self._zAxis
def setVminVmaxAuto(self, vmin, vmax, scalars):
"""
Automatically set vmin and vmax as min/max of scalars
but only if vmin or vmax is None
"""
if scalars is None:
return
if len(scalars) < 2:
warnings.warn("Need at least two scalars to autoset vmin and/or vmax!")
return
if vmin is None:
vmin = min(scalars)
self.plot_kwargs['vmin'] = vmin
if vmax is None:
vmax = max(scalars)
self.plot_kwargs['vmax'] = vmax
def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None):
if cmapDefault is None:
cmapDefault = plt.rcParams['image.cmap']
cmap = self.get('cmap', cmapDefault)
vmin = self.get('vmin', vminDefault)
vmax = self.get('vmax', vmaxDefault)
return cmap, vmin, vmax
def format_colors(self, colors, fmt='rgba', length=None):
"""
format colors according to fmt argument
Args:
colors (list/one value of rgba tuples/int/float/str): This argument will
be interpreted as color
fmt (str): rgba | hex | norm
length (int/None): if not None: correct colors lenght
Returns:
colors in fmt
"""
hasIter = True
if not hasattr(colors, '__iter__') or isinstance(colors, string_types):
# colors is just one element
hasIter = False
colors = [colors]
if fmt == 'norm':
if hasattr(colors[0], '__iter__'):
# rgba given but norm wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
colors = to_scalars(colors, cmap, vmin, vmax)
self.plot_kwargs['vmin'] = vmin
self.plot_kwargs['vmax'] = vmax
self.plot_kwargs['cmap'] = cmap
elif fmt == 'rgba':
if isinstance(colors[0], string_types):
# string color defined
colors = [mpl.colors.to_rgba(color) for color in colors]
else:
# norm given rgba wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
self.setVminVmaxAuto(vmin, vmax, colors)
# update vmin and vmax
cmap, vmin, vmax = self.getNormArgs()
colors = to_colors(colors,
vmin=vmin,
vmax=vmax,
cmap=cmap)
elif fmt == 'hex':
colors = [mpl.colors.to_hex(color) for color in colors]
else:
raise NotImplementedError("Color fmt '{fmt}' not implemented."
.format(**locals()))
if length is not None:
# just one colors value given
if len(colors) != length:
if not len(colors) == 1:
raise ValueError("Can not correct color length")
colors = list(colors)
colors *= length
elif not hasIter:
colors = colors[0]
colors = np.array(colors)
return colors
def delNormArgs(self):
self.plot_kwargs.pop('vmin', None)
self.plot_kwargs.pop('vmax', None)
self.plot_kwargs.pop('cmap', None)
def getSortedLabels(self, labels):
"""
Returns the labels corresponding to the axes
"""
return [labels[i] for i in self.getXYZAxis() if i is not None]
def get(self, attr, default=None):
return self.plot_kwargs.get(attr, default)
def pop(self, attr, default=None):
return self.plot_kwargs.pop(attr, default)
def set(self, attr, value):
self.plot_kwargs[attr] = value
def set_default(self, attr, value):
set_default(self.plot_kwargs, attr, value)
def retrieve(self, attr, default=None, keep=True):
if keep:
return self.get(attr, default)
else:
return self.pop(attr, default)
def retrieve_chain(self, *args, **kwargs):
default = kwargs.pop('default', None)
keep = kwargs.pop('keep', True)
if len(args) > 1:
return self.retrieve(args[0],
self.retrieve_chain(*args[1:],
default=default,
keep=keep),
keep=keep)
if len(args) != 1:
raise ValueError("Invalid number of args ({0})".format(len(args)))
return self.retrieve(args[0], default, keep=keep)
if __name__ == '__main__':
import doctest
doctest.testmod()
This diff is collapsed.
legend.numpoints : 1 # the number of points in the legend line
legend.scatterpoints : 1
figure.figsize: 12.8, 8.8
axes.labelsize: 20
axes.titlesize: 24
xtick.labelsize: 20
ytick.labelsize: 20
legend.fontsize: 20
grid.linewidth: 1.6
lines.linewidth: 2.8
patch.linewidth: 0.48
lines.markersize: 11.2
lines.markeredgewidth: 0
xtick.major.width: 1.6
ytick.major.width: 1.6
xtick.minor.width: 0.8
ytick.minor.width: 0.8
xtick.major.pad: 11.2
ytick.major.pad: 11.2
savefig.transparent : True
savefig.dpi: 300
figure.autolayout : True
# IMAGES
image.cmap: viridis
# SAVING / saving
# savefig.dpi: 600
# LATEX / LaTeX / latex
text.usetex : True
font.family : serif
font.size : 40 # 20 was a duplicate definition
text.latex.preamble : \usepackage{amsmath} # for \text etc
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment