Commit 5327d169 authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

Container added

parent 33227e94
......@@ -7,7 +7,7 @@ from .lib import *
from . import plotting
# __all__ = ['core', 'points3D']
from .core import Tensors, TensorFields, TensorMaps
from .core import Tensors, TensorFields, TensorMaps, Container
from .points3D import Points3D
from .mask import evalf
......
......@@ -6,6 +6,11 @@ Mail: daniel.boeckenhoff@ipp.mpg.de
core of tfields library
contains numpy ndarray derived bases of the tfields package
Notes:
It could be worthwhile concidering np.li.mixins.NDArrayOperatorsMixin
... see https://docs.scipy.org/doc/numpy-1.15.1/reference
/generated/numpy.lib.mixins.NDArrayOperatorsMixin.html
"""
import warnings
import os
......@@ -164,6 +169,14 @@ class AbstractNdarray(np.ndarray):
index = -(i + 1)
setattr(self, slot, state[index])
@property
def bulk(self):
"""
The pure ndarray version of the actual state
-> nothing attached
"""
return np.array(self)
def copy(self, *args, **kwargs):
"""
The standard ndarray copy does not copy slots. Correct for this.
......@@ -696,14 +709,6 @@ class Tensors(AbstractNdarray):
**cls_kwargs)
return inst
@property
def bulk(self):
"""
The pure ndarray version of the actual state
-> nothing attached
"""
return np.array(self)
@property
def rank(self):
"""
......@@ -1881,6 +1886,10 @@ class TensorMaps(TensorFields):
used like: self.maps[map_pos_idx]
map_indices_list (list of list of int): each int refers
to index in a map.
Returns:
list of type(self): One TensorMaps or TensorMaps subclass per
map_description
"""
# raise ValueError(map_descriptions)
parts = []
......@@ -1900,6 +1909,8 @@ class TensorMaps(TensorFields):
def disjoint_map(self, mp_idx):
"""
Find the disjoint sets of map = self.maps[mp_idx]
As an example, this method is interesting for splitting a mesh
consisting of seperate parts
Args:
mp_idx (int): reference to map position
used like: self.maps[mp_idx]
......@@ -1926,6 +1937,40 @@ class TensorMaps(TensorFields):
return (0, maps_list)
class Container(AbstractNdarray):
"""
Story lists of tfields objects. Save mechanisms are provided
Examples:
>>> import numpy as np
>>> import tfields
>>> sphere = tfields.Mesh3D.grid(
... (1, 1, 1),
... (-np.pi, np.pi, 3),
... (-np.pi / 2, np.pi / 2, 3),
... coord_sys='spherical')
>>> sphere2 = sphere.copy() * 3
>>> c = tfields.Container([sphere, sphere2])
# >>> c.save("~/tmp/spheres.npz")
# >>> c1 = tfields.Container.load("~/tmp/spheres.npz")
"""
__slots__ = ['items', 'labels']
def __new__(cls, items, **kwargs):
kwargs['items'] = items
cls._update_slot_kwargs(kwargs)
empty = np.empty(0, int)
obj = empty.view(cls)
''' set kwargs to slots attributes '''
for attr in kwargs:
if attr not in cls._iter_slots():
raise AttributeError("Keyword argument {attr} not accepted "
"for class {cls}".format(**locals()))
setattr(obj, attr, kwargs[attr])
return obj
if __name__ == '__main__': # pragma: no cover
import doctest
doctest.testmod()
......
......@@ -42,6 +42,25 @@ def resolve(path):
return os.path.realpath(os.path.abspath(os.path.expanduser(path)))
def provide(cls, path, function, *args, **kwargs):
"""
provide the object of type tfields_clas saved unter path
and generated by function with *args and **kwargs
Args:
cls (type):
"""
path = resolve(path)
if os.path.exists(path):
return cls.load(path)
obj = function(*args, **kwargs)
if not isinstance(obj, cls):
return_type = type(obj)
raise TypeError("Return value of function is not {cls} but"
"{return_type}".format(**locals()))
obj.save(path)
return obj
def cp(source, dest, overwrite=True):
"""
copy with shutil
......
......@@ -33,12 +33,12 @@ class PlotOptions(object):
self.dim = kwargs.pop('dim', None)
self.method = kwargs.pop('methodName', None)
self.setXYZAxis(kwargs)
self.plotKwargs = kwargs
self.plot_kwargs = kwargs
@property
def method(self):
"""
Method for plotting. Will be callable together with plotKwargs
Method for plotting. Will be callable together with plot_kwargs
"""
return self._method
......@@ -109,10 +109,10 @@ class PlotOptions(object):
return
if vmin is None:
vmin = min(scalars)
self.plotKwargs['vmin'] = vmin
self.plot_kwargs['vmin'] = vmin
if vmax is None:
vmax = max(scalars)
self.plotKwargs['vmax'] = vmax
self.plot_kwargs['vmax'] = vmax
def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None):
if cmapDefault is None:
......@@ -140,15 +140,16 @@ class PlotOptions(object):
hasIter = False
colors = [colors]
if hasattr(colors[0], '__iter__') and fmt == 'norm':
# rgba given but norm wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
colors = to_scalars(colors, cmap, vmin, vmax)
self.plotKwargs['vmin'] = vmin
self.plotKwargs['vmax'] = vmax
self.plotKwargs['cmap'] = cmap
if fmt == 'norm':
if hasattr(colors[0], '__iter__'):
# rgba given but norm wanted
cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
vminDefault=None,
vmaxDefault=None)
colors = to_scalars(colors, cmap, vmin, vmax)
self.plot_kwargs['vmin'] = vmin
self.plot_kwargs['vmax'] = vmax
self.plot_kwargs['cmap'] = cmap
elif fmt == 'rgba':
if isinstance(colors[0], string_types):
# string color defined
......@@ -165,10 +166,11 @@ class PlotOptions(object):
vmin=vmin,
vmax=vmax,
cmap=cmap)
print(colors)
elif fmt == 'hex':
colors = [mpl.colors.to_hex(color) for color in colors]
else:
raise NotImplementedError("Color fmt {fmt} not implemented."
raise NotImplementedError("Color fmt '{fmt}' not implemented."
.format(**locals()))
if length is not None:
......@@ -185,9 +187,9 @@ class PlotOptions(object):
return colors
def delNormArgs(self):
self.plotKwargs.pop('vmin', None)
self.plotKwargs.pop('vmax', None)
self.plotKwargs.pop('cmap', None)
self.plot_kwargs.pop('vmin', None)
self.plot_kwargs.pop('vmax', None)
self.plot_kwargs.pop('cmap', None)
def getSortedLabels(self, labels):
"""
......@@ -196,16 +198,16 @@ class PlotOptions(object):
return [labels[i] for i in self.getXYZAxis() if i is not None]
def get(self, attr, default=None):
return self.plotKwargs.get(attr, default)
return self.plot_kwargs.get(attr, default)
def pop(self, attr, default=None):
return self.plotKwargs.pop(attr, default)
return self.plot_kwargs.pop(attr, default)
def set(self, attr, value):
self.plotKwargs[attr] = value
self.plot_kwargs[attr] = value
def set_default(self, attr, value):
set_default(self.plotKwargs, attr, value)
set_default(self.plot_kwargs, attr, value)
def retrieve(self, attr, default=None, keep=True):
if keep:
......
......@@ -198,7 +198,7 @@ def plot_array(array, **kwargs):
array[:, yAxis],
array[:, zAxis]]
artist = po.method(*args,
**po.plotKwargs)
**po.plot_kwargs)
return artist
......@@ -258,8 +258,8 @@ def plot_mesh(vertices, faces, **kwargs):
vertices = mesh
po.plotKwargs['methodName'] = 'tripcolor'
po.plotKwargs['triangles'] = faces
po.plot_kwargs['methodName'] = 'tripcolor'
po.plot_kwargs['triangles'] = faces
"""
sort out color arguments
......@@ -269,9 +269,9 @@ def plot_mesh(vertices, faces, **kwargs):
length=nFacesInitial)
if not full:
facecolors = facecolors[dotProduct > 0]
po.plotKwargs['facecolors'] = facecolors
po.plot_kwargs['facecolors'] = facecolors
d = po.plotKwargs
d = po.plot_kwargs
d['xAxis'] = xAxis
d['yAxis'] = yAxis
artist = plot_array(vertices, **d)
......@@ -294,7 +294,7 @@ def plot_mesh(vertices, faces, **kwargs):
po.delNormArgs()
triangles = np.array([vertices[face] for face in faces])
artist = plt3D.art3d.Poly3DCollection(triangles, **po.plotKwargs)
artist = plt3D.art3d.Poly3DCollection(triangles, **po.plot_kwargs)
po.axis.add_collection3d(artist)
if edgecolor is not None:
......@@ -341,11 +341,11 @@ def plot_tensor_field(points, vectors, **kwargs):
if po.dim == 3:
artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis],
vector[xAxis], vector[yAxis], vector[zAxis],
**po.plotKwargs))
**po.plot_kwargs))
elif po.dim == 2:
artists.append(po.axis.quiver(point[xAxis], point[yAxis],
vector[xAxis], vector[yAxis],
**po.plotKwargs))
**po.plot_kwargs))
else:
raise NotImplementedError("Dimension != 2|3")
return artists
......@@ -406,7 +406,7 @@ def plot_plane(point, normal, **kwargs):
kwargs['alpha'] = kwargs.pop('alpha', 0.5)
po = tfields.plotting.PlotOptions(kwargs)
patch = Circle((0, 0), **po.plotKwargs)
patch = Circle((0, 0), **po.plot_kwargs)
po.axis.add_patch(patch)
pathpatch_2d_to_3d(patch, z=0, normal=normal)
pathpatch_translate(patch, (point[0], point[1], point[2]))
......@@ -422,7 +422,7 @@ def plot_sphere(point, radius, **kwargs):
z = point[2] + radius * np.outer(np.ones(np.size(u)), np.cos(v))
# Plot the surface
return po.axis.plot_surface(x, y, z, **po.plotKwargs)
return po.axis.plot_surface(x, y, z, **po.plot_kwargs)
def plot_function(fun, **kwargs):
......@@ -443,7 +443,7 @@ def plot_function(fun, **kwargs):
vals = np.linspace(xMin, xMax, n)
args = (vals, map(fun, vals))
artist = po.axis.plot(*args,
**po.plotKwargs)
**po.plot_kwargs)
return artist
......@@ -465,7 +465,7 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs):
artists = []
# plot points
# artists.append(po.axis.plot(points, **po.plotKwargs))
# artists.append(po.axis.plot(points, **po.plot_kwargs))
# plot errorbars
for i in range(len(points)):
......@@ -474,19 +474,19 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs):
points[i, 0] - errors_down[i, 0]],
[points[i, 1], points[i, 1]],
[points[i, 2], points[i, 2]],
**po.plotKwargs))
**po.plot_kwargs))
artists.append(
po.axis.plot([points[i, 0], points[i, 0]],
[points[i, 1] + errors_up[i, 1],
points[i, 1] - errors_down[i, 1]],
[points[i, 2], points[i, 2]],
**po.plotKwargs))
**po.plot_kwargs))
artists.append(
po.axis.plot([points[i, 0], points[i, 0]],
[points[i, 1], points[i, 1]],
[points[i, 2] + errors_up[i, 2],
points[i, 2] - errors_down[i, 2]],
**po.plotKwargs))
**po.plot_kwargs))
return artists
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment