From 5327d1694e310973a2acd3d7911855b02d5ec71a Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Fri, 18 Jan 2019 14:59:03 +0100
Subject: [PATCH] Container added

---
 tfields/__init__.py          |  2 +-
 tfields/core.py              | 61 +++++++++++++++++++++++++++++++-----
 tfields/lib/in_out.py        | 19 +++++++++++
 tfields/plotting/__init__.py | 44 +++++++++++++-------------
 tfields/plotting/mpl.py      | 30 +++++++++---------
 5 files changed, 111 insertions(+), 45 deletions(-)

diff --git a/tfields/__init__.py b/tfields/__init__.py
index fb36f4b..e192083 100644
--- a/tfields/__init__.py
+++ b/tfields/__init__.py
@@ -7,7 +7,7 @@ from .lib import *
 from . import plotting
 
 # __all__ = ['core', 'points3D']
-from .core import Tensors, TensorFields, TensorMaps
+from .core import Tensors, TensorFields, TensorMaps, Container
 from .points3D import Points3D
 from .mask import evalf
 
diff --git a/tfields/core.py b/tfields/core.py
index 3ea0f2f..f215b3e 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -6,6 +6,11 @@ Mail:       daniel.boeckenhoff@ipp.mpg.de
 
 core of tfields library
 contains numpy ndarray derived bases of the tfields package
+
+Notes:
+    It could be worthwhile concidering np.li.mixins.NDArrayOperatorsMixin
+    ... see https://docs.scipy.org/doc/numpy-1.15.1/reference
+            /generated/numpy.lib.mixins.NDArrayOperatorsMixin.html
 """
 import warnings
 import os
@@ -164,6 +169,14 @@ class AbstractNdarray(np.ndarray):
             index = -(i + 1)
             setattr(self, slot, state[index])
 
+    @property
+    def bulk(self):
+        """
+        The pure ndarray version of the actual state
+            -> nothing attached
+        """
+        return np.array(self)
+
     def copy(self, *args, **kwargs):
         """
         The standard ndarray copy does not copy slots. Correct for this.
@@ -696,14 +709,6 @@ class Tensors(AbstractNdarray):
                            **cls_kwargs)
         return inst
 
-    @property
-    def bulk(self):
-        """
-        The pure ndarray version of the actual state
-            -> nothing attached
-        """
-        return np.array(self)
-
     @property
     def rank(self):
         """
@@ -1881,6 +1886,10 @@ class TensorMaps(TensorFields):
                     used like: self.maps[map_pos_idx]
                 map_indices_list (list of list of int): each int refers
                     to index in a map.
+
+        Returns:
+            list of type(self): One TensorMaps or TensorMaps subclass per
+                map_description
         """
         # raise ValueError(map_descriptions)
         parts = []
@@ -1900,6 +1909,8 @@ class TensorMaps(TensorFields):
     def disjoint_map(self, mp_idx):
         """
         Find the disjoint sets of map = self.maps[mp_idx]
+        As an example, this method is interesting for splitting a mesh
+        consisting of seperate parts
         Args:
             mp_idx (int): reference to map position
                 used like: self.maps[mp_idx]
@@ -1926,6 +1937,40 @@ class TensorMaps(TensorFields):
         return (0, maps_list)
 
 
+class Container(AbstractNdarray):
+    """
+    Story lists of tfields objects. Save mechanisms are provided
+    Examples:
+        >>> import numpy as np
+        >>> import tfields
+        >>> sphere = tfields.Mesh3D.grid(
+        ...     (1, 1, 1),
+        ...     (-np.pi, np.pi, 3),
+        ...     (-np.pi / 2, np.pi / 2, 3),
+        ...     coord_sys='spherical')
+        >>> sphere2 = sphere.copy() * 3
+        >>> c = tfields.Container([sphere, sphere2])
+
+        # >>> c.save("~/tmp/spheres.npz")
+        # >>> c1 = tfields.Container.load("~/tmp/spheres.npz")
+    """
+    __slots__ = ['items', 'labels']
+    def __new__(cls, items, **kwargs):
+        kwargs['items'] = items
+        cls._update_slot_kwargs(kwargs)
+
+        empty = np.empty(0, int)
+        obj = empty.view(cls)
+
+        ''' set kwargs to slots attributes '''
+        for attr in kwargs:
+            if attr not in cls._iter_slots():
+                raise AttributeError("Keyword argument {attr} not accepted "
+                                     "for class {cls}".format(**locals()))
+            setattr(obj, attr, kwargs[attr])
+        return obj
+
+
 if __name__ == '__main__':  # pragma: no cover
     import doctest
     doctest.testmod()
diff --git a/tfields/lib/in_out.py b/tfields/lib/in_out.py
index 8809a9e..9d2e746 100644
--- a/tfields/lib/in_out.py
+++ b/tfields/lib/in_out.py
@@ -42,6 +42,25 @@ def resolve(path):
     return os.path.realpath(os.path.abspath(os.path.expanduser(path)))
 
 
+def provide(cls, path, function, *args, **kwargs):
+    """
+    provide the object of type tfields_clas saved unter path
+    and generated by function with *args and **kwargs
+    Args:
+        cls (type): 
+    """
+    path = resolve(path)
+    if os.path.exists(path):
+        return cls.load(path)
+    obj = function(*args, **kwargs)
+    if not isinstance(obj, cls):
+        return_type = type(obj)
+        raise TypeError("Return value of function is not {cls} but"
+                        "{return_type}".format(**locals()))
+    obj.save(path)
+    return obj
+
+
 def cp(source, dest, overwrite=True):
     """
     copy with shutil
diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py
index cdd636e..8f14760 100644
--- a/tfields/plotting/__init__.py
+++ b/tfields/plotting/__init__.py
@@ -33,12 +33,12 @@ class PlotOptions(object):
         self.dim = kwargs.pop('dim', None)
         self.method = kwargs.pop('methodName', None)
         self.setXYZAxis(kwargs)
-        self.plotKwargs = kwargs
+        self.plot_kwargs = kwargs
 
     @property
     def method(self):
         """
-        Method for plotting. Will be callable together with plotKwargs
+        Method for plotting. Will be callable together with plot_kwargs
         """
         return self._method
 
@@ -109,10 +109,10 @@ class PlotOptions(object):
             return
         if vmin is None:
             vmin = min(scalars)
-            self.plotKwargs['vmin'] = vmin
+            self.plot_kwargs['vmin'] = vmin
         if vmax is None:
             vmax = max(scalars)
-            self.plotKwargs['vmax'] = vmax
+            self.plot_kwargs['vmax'] = vmax
 
     def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None):
         if cmapDefault is None:
@@ -140,15 +140,16 @@ class PlotOptions(object):
             hasIter = False
             colors = [colors]
 
-        if hasattr(colors[0], '__iter__') and fmt == 'norm':
-            # rgba given but norm wanted
-            cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
-                                                vminDefault=None,
-                                                vmaxDefault=None)
-            colors = to_scalars(colors, cmap, vmin, vmax)
-            self.plotKwargs['vmin'] = vmin
-            self.plotKwargs['vmax'] = vmax
-            self.plotKwargs['cmap'] = cmap
+        if fmt == 'norm':
+            if hasattr(colors[0], '__iter__'):
+                # rgba given but norm wanted
+                cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified',
+                                                    vminDefault=None,
+                                                    vmaxDefault=None)
+                colors = to_scalars(colors, cmap, vmin, vmax)
+                self.plot_kwargs['vmin'] = vmin
+                self.plot_kwargs['vmax'] = vmax
+                self.plot_kwargs['cmap'] = cmap
         elif fmt == 'rgba':
             if isinstance(colors[0], string_types):
                 # string color defined
@@ -165,10 +166,11 @@ class PlotOptions(object):
                                    vmin=vmin,
                                    vmax=vmax,
                                    cmap=cmap)
+            print(colors)
         elif fmt == 'hex':
             colors = [mpl.colors.to_hex(color) for color in colors]
         else:
-            raise NotImplementedError("Color fmt {fmt} not implemented."
+            raise NotImplementedError("Color fmt '{fmt}' not implemented."
                                       .format(**locals()))
     
         if length is not None:
@@ -185,9 +187,9 @@ class PlotOptions(object):
         return colors
 
     def delNormArgs(self):
-        self.plotKwargs.pop('vmin', None)
-        self.plotKwargs.pop('vmax', None)
-        self.plotKwargs.pop('cmap', None)
+        self.plot_kwargs.pop('vmin', None)
+        self.plot_kwargs.pop('vmax', None)
+        self.plot_kwargs.pop('cmap', None)
 
     def getSortedLabels(self, labels):
         """
@@ -196,16 +198,16 @@ class PlotOptions(object):
         return [labels[i] for i in self.getXYZAxis() if i is not None]
         
     def get(self, attr, default=None):
-        return self.plotKwargs.get(attr, default)
+        return self.plot_kwargs.get(attr, default)
 
     def pop(self, attr, default=None):
-        return self.plotKwargs.pop(attr, default)
+        return self.plot_kwargs.pop(attr, default)
 
     def set(self, attr, value):
-        self.plotKwargs[attr] = value
+        self.plot_kwargs[attr] = value
 
     def set_default(self, attr, value):
-        set_default(self.plotKwargs, attr, value)
+        set_default(self.plot_kwargs, attr, value)
 
     def retrieve(self, attr, default=None, keep=True):
         if keep:
diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py
index af4e96f..435eb3a 100644
--- a/tfields/plotting/mpl.py
+++ b/tfields/plotting/mpl.py
@@ -198,7 +198,7 @@ def plot_array(array, **kwargs):
                 array[:, yAxis],
                 array[:, zAxis]]
     artist = po.method(*args,
-                       **po.plotKwargs)
+                       **po.plot_kwargs)
     return artist
 
 
@@ -258,8 +258,8 @@ def plot_mesh(vertices, faces, **kwargs):
 
         vertices = mesh
 
-        po.plotKwargs['methodName'] = 'tripcolor'
-        po.plotKwargs['triangles'] = faces
+        po.plot_kwargs['methodName'] = 'tripcolor'
+        po.plot_kwargs['triangles'] = faces
 
         """
         sort out color arguments
@@ -269,9 +269,9 @@ def plot_mesh(vertices, faces, **kwargs):
                                       length=nFacesInitial)
         if not full:
             facecolors = facecolors[dotProduct > 0]
-        po.plotKwargs['facecolors'] = facecolors
+        po.plot_kwargs['facecolors'] = facecolors
 
-        d = po.plotKwargs
+        d = po.plot_kwargs
         d['xAxis'] = xAxis
         d['yAxis'] = yAxis
         artist = plot_array(vertices, **d)
@@ -294,7 +294,7 @@ def plot_mesh(vertices, faces, **kwargs):
         po.delNormArgs()
 
         triangles = np.array([vertices[face] for face in faces])
-        artist = plt3D.art3d.Poly3DCollection(triangles, **po.plotKwargs)
+        artist = plt3D.art3d.Poly3DCollection(triangles, **po.plot_kwargs)
         po.axis.add_collection3d(artist)
 
         if edgecolor is not None:
@@ -341,11 +341,11 @@ def plot_tensor_field(points, vectors, **kwargs):
         if po.dim == 3:
             artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis],
                                           vector[xAxis], vector[yAxis], vector[zAxis],
-                                          **po.plotKwargs))
+                                          **po.plot_kwargs))
         elif po.dim == 2:
             artists.append(po.axis.quiver(point[xAxis], point[yAxis],
                                           vector[xAxis], vector[yAxis],
-                                          **po.plotKwargs))
+                                          **po.plot_kwargs))
         else:
             raise NotImplementedError("Dimension != 2|3")
     return artists
@@ -406,7 +406,7 @@ def plot_plane(point, normal, **kwargs):
 
     kwargs['alpha'] = kwargs.pop('alpha', 0.5)
     po = tfields.plotting.PlotOptions(kwargs)
-    patch = Circle((0, 0), **po.plotKwargs)
+    patch = Circle((0, 0), **po.plot_kwargs)
     po.axis.add_patch(patch)
     pathpatch_2d_to_3d(patch, z=0, normal=normal)
     pathpatch_translate(patch, (point[0], point[1], point[2]))
@@ -422,7 +422,7 @@ def plot_sphere(point, radius, **kwargs):
     z = point[2] + radius * np.outer(np.ones(np.size(u)), np.cos(v))
 
     # Plot the surface
-    return po.axis.plot_surface(x, y, z, **po.plotKwargs)
+    return po.axis.plot_surface(x, y, z, **po.plot_kwargs)
 
 
 def plot_function(fun, **kwargs):
@@ -443,7 +443,7 @@ def plot_function(fun, **kwargs):
     vals = np.linspace(xMin, xMax, n)
     args = (vals, map(fun, vals))
     artist = po.axis.plot(*args,
-                          **po.plotKwargs)
+                          **po.plot_kwargs)
     return artist
 
 
@@ -465,7 +465,7 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs):
     artists = []
 
     # plot points
-    # artists.append(po.axis.plot(points, **po.plotKwargs))
+    # artists.append(po.axis.plot(points, **po.plot_kwargs))
 
     # plot errorbars
     for i in range(len(points)):
@@ -474,19 +474,19 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs):
                           points[i, 0] - errors_down[i, 0]],
                          [points[i, 1], points[i, 1]],
                          [points[i, 2], points[i, 2]],
-                         **po.plotKwargs))
+                         **po.plot_kwargs))
         artists.append(
             po.axis.plot([points[i, 0], points[i, 0]],
                          [points[i, 1] + errors_up[i, 1],
                           points[i, 1] - errors_down[i, 1]],
                          [points[i, 2], points[i, 2]],
-                         **po.plotKwargs))
+                         **po.plot_kwargs))
         artists.append(
             po.axis.plot([points[i, 0], points[i, 0]],
                          [points[i, 1], points[i, 1]],
                          [points[i, 2] + errors_up[i, 2],
                           points[i, 2] - errors_down[i, 2]],
-                         **po.plotKwargs))
+                         **po.plot_kwargs))
 
     return artists
 
-- 
GitLab