From 9bc2a23f709c19adc2d7c49638ff627a39bae4cc Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Fri, 3 Aug 2018 18:43:37 +0200
Subject: [PATCH] np.savez implemented generally for all sub types of Tensors

---
 setup.py                |   2 +
 tfields/__init__.py     |   2 +-
 tfields/bases/bases.py  |   5 +-
 tfields/core.py         | 127 +++++++++++++++++++++++++++++++++-------
 tfields/mesh3D.py       |   3 +-
 tfields/plotting/mpl.py |   9 ++-
 tfields/triangles3D.py  |   4 ++
 7 files changed, 122 insertions(+), 30 deletions(-)

diff --git a/setup.py b/setup.py
index aa93c74..33fce85 100644
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,8 @@ setup(
         'sympy',
         'scipy',
         'pathlib',
+        'six',
+        'matplotlib',
     ],
     entry_points={
         'console_scripts': ['tfields = tfields.__main__:runDoctests']
diff --git a/tfields/__init__.py b/tfields/__init__.py
index 9518a8e..6565cde 100644
--- a/tfields/__init__.py
+++ b/tfields/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.1.0.dev4'
+__version__ = '0.1.0.dev5'
 __author__ = "Daniel Boeckenhoff"
 __email__ = "daniel.boeckenhoff@ipp.mpg.de"
 
diff --git a/tfields/bases/bases.py b/tfields/bases/bases.py
index f66b21f..c94ef72 100644
--- a/tfields/bases/bases.py
+++ b/tfields/bases/bases.py
@@ -22,8 +22,9 @@ def get_coord_system(base):
     Return:
         sympy.diffgeom.get_coord_system
     """
-    if isinstance(base, string_types):
-        base = getattr(tfields.bases, base)
+    if (isinstance(base, string_types) 
+        or (isinstance(base, np.ndarray) and base.dtype.kind in {'U', 'S'})):
+        base = getattr(tfields.bases, str(base))
     if not isinstance(base, sympy.diffgeom.CoordSystem):
         bse_tpe = type(base)
         expctd_tpe = type(sympy.diffgeom.CoordSystem)
diff --git a/tfields/core.py b/tfields/core.py
index 6075736..8c08e3d 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -115,6 +115,8 @@ class AbstractNdarray(np.ndarray):
     def __reduce__(self):
         """
         important for pickling
+        see https://stackoverflow.com/questions/26598109/
+            preserve-custom-attributes-when-pickling-subclass-of-numpy-array
         Examples:
             >>> from tempfile import NamedTemporaryFile
             >>> import pickle
@@ -127,13 +129,13 @@ class AbstractNdarray(np.ndarray):
             >>> scalarField = TensorFields(vectors, scalars, coord_sys='cylinder')
 
             Save it and restore it
-            >>> outFile = NamedTemporaryFile(suffix='.pickle')
+            >>> out_file = NamedTemporaryFile(suffix='.pickle')
 
             >>> pickle.dump(scalarField,
-            ...             outFile)
-            >>> _ = outFile.seek(0)
+            ...             out_file)
+            >>> _ = out_file.seek(0)
 
-            >>> sf = pickle.load(outFile)
+            >>> sf = pickle.load(out_file)
             >>> sf.coord_sys == 'cylinder'
             True
             >>> sf.fields[0][2] == 2.
@@ -250,26 +252,34 @@ class AbstractNdarray(np.ndarray):
         Args:
             path (open file or str/unicode): destination to save file to.
         Examples:
+            Build some dummies:
             >>> import tfields
             >>> from tempfile import NamedTemporaryFile
-            >>> outFile = NamedTemporaryFile(suffix='.npz')
+            >>> out_file = NamedTemporaryFile(suffix='.npz')
             >>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
-            >>> p.save(outFile.name)
-            >>> _ = outFile.seek(0)
-            >>> p1 = tfields.Points3D.load(outFile.name)
+
+            >>> scalars = tfields.Tensors([0, 1, 2])
+            >>> vectors = tfields.Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0]])
+            >>> maps = [tfields.TensorFields([[0, 1, 2], [0, 1, 2]], [42, 21]),
+            ...         tfields.TensorFields([[1], [2]], [-42, -21])]
+            >>> m = tfields.TensorMaps(vectors, scalars,
+            ...                        maps=maps)
+
+            Simply give the file name to save
+            >>> p.save(out_file.name)
+            >>> _ = out_file.seek(0)
+            >>> p1 = tfields.Points3D.load(out_file.name)
             >>> assert p.equal(p1)
 
-        """
-        kwargs = {}
-        for attr in self._iter_slots():
-            if not hasattr(self, attr):
-                # attribute in __slots__ not found.
-                warnings.warn("When saving instance of class {0} Attribute {1} not set."
-                              "This Attribute is not saved.".format(self.__class__, attr), Warning)
-            else:
-                kwargs[attr] = getattr(self, attr)
+            The fully nested structure of a TensorMaps object is reconstructed
+            >>> out_file_maps = NamedTemporaryFile(suffix='.npz')
+            >>> m.save(out_file_maps.name)
+            >>> _ = out_file_maps.seek(0)
+            >>> m1 = tfields.TensorMaps.load(out_file_maps.name)
+            >>> assert m.equal(m1)
 
-        np.savez(path, self, **kwargs) 
+        """
+        np.savez(path, **self._as_dict())
 
     @classmethod
     def _load_npz(cls, path, **load_kwargs):
@@ -278,10 +288,82 @@ class AbstractNdarray(np.ndarray):
         Given a path to a npz file, construct the object
         """
         np_file = np.load(path, **load_kwargs)
-        keys = np_file.keys()
-        bulk = np_file['arr_0']
-        data_kwargs = {key: np_file[key] for key in keys if key not in ['arr_0']}
-        return cls.__new__(cls, bulk, **data_kwargs)
+        return cls._from_dict(**np_file)
+
+    def _as_dict(self):
+        """
+        Recursively walk trough all __slots__ and describe all elements
+        """
+        d = {}
+        d['bulk'] = np.array(self)
+        d['bulk_type'] = self.__class__.__name__
+        for attr in self._iter_slots():
+            value = getattr(self, attr)
+            if isinstance(value, list):
+                if len(value) == 0:
+                    d[attr] = None
+                if all([isinstance(part, AbstractNdarray) for part in value]):
+                    for i, part in enumerate(value):
+                        part_dict = part._as_dict()
+                        for part_attr, part_value in part_dict.items():
+                            d["{attr}::{i}::{part_attr}".format(**locals())] = part_value
+                continue
+            if isinstance(value, AbstractNdarray):
+                value = value._as_dict()
+            d[attr] = value
+        return d
+
+    @classmethod
+    def _from_dict(cls, **d):
+        """
+        Opposite of _as_dict
+        """
+        list_dict = {}
+        kwargs = {}
+        '''
+        De-Flatten the first layer of lists
+        '''
+        for key in sorted(d.keys()):
+            if '::' in key:
+                splits = key.split('::')
+                attr, _, end = key.partition('::')
+                if attr not in list_dict:
+                    list_dict[attr] = {}
+
+                index, _, end = end.partition('::')
+                if not index.isdigit():
+                    raise ValueError("None digit index given")
+                index = int(index)
+                if index not in list_dict[attr]:
+                    list_dict[attr][index] = {}
+                list_dict[attr][index][end] = d[key]
+            else:
+                kwargs[key] = d[key]
+
+        '''
+        Build the lists (recursively)
+        '''
+        for key in list_dict.keys():
+            sub_dict = list_dict[key]
+            list_dict[key] = []
+            for index in sorted(sub_dict.keys()):
+                bulk_type = sub_dict[index].get('bulk_type')
+                bulk_type = getattr(tfields, bulk_type.tolist())
+                list_dict[key].append(bulk_type._from_dict(**sub_dict[index]))
+
+        '''
+        Build the normal way
+        '''
+        bulk = kwargs.pop('bulk')
+        bulk_type = kwargs.pop('bulk_type')
+        obj = cls.__new__(cls, bulk, **kwargs)
+
+        '''
+        Set list attributes
+        '''
+        for attr, list_value in list_dict.items():
+            setattr(obj, attr, list_value)
+        return obj
 
 
 class Tensors(AbstractNdarray):
@@ -1753,3 +1835,4 @@ if __name__ == '__main__':  # pragma: no cover
     import doctest
     doctest.testmod()
     # doctest.run_docstring_examples(TensorFields.__getitem__, globals())
+    # doctest.run_docstring_examples(AbstractNdarray._save_npz, globals())
diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py
index 8d17498..b23c5b3 100644
--- a/tfields/mesh3D.py
+++ b/tfields/mesh3D.py
@@ -200,8 +200,7 @@ class Mesh3D(tfields.TensorMaps):
         >>> m1 = tfields.Mesh3D.load(outFile.name)
         >>> bool(np.all(m == m1))
         True
-        >>> m1.faces
-        array([[0, 1, 2]])
+        >>> assert np.array_equal(m1.faces, np.array([[0, 1, 2]]))
 
     """
     def __new__(cls, tensors, *fields, **kwargs):
diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py
index 3367fbf..1d39712 100644
--- a/tfields/plotting/mpl.py
+++ b/tfields/plotting/mpl.py
@@ -121,7 +121,10 @@ def save(path, *fmts, **kwargs):
             if first_label:
                 if not extra_artists:
                     extra_artists = []
-                extra_artists.append(first_label)
+                if isinstance(first_label, list):
+                    extra_artists.extend(first_label)
+                else:
+                    extra_artists.append(first_label)
         kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists',
                                                   extra_artists)
 
@@ -564,7 +567,7 @@ def autoscale_3d(axis, array=None, xLim=None, yLim=None, zLim=None):
     axis.set_zlim([zMin, zMax])
 
 
-def set_legend(axis, artists):
+def set_legend(axis, artists, **kwargs):
     """
     Convenience method to set a legend from multiple artists to an axis.
     """
@@ -574,7 +577,7 @@ def set_legend(axis, artists):
             handles.append(artist[0])
         else:
             handles.append(artist)
-    axis.legend(handles=handles)
+    return axis.legend(handles=handles, **kwargs)
 
 
 def set_colorbar(axis, artist, label=None, divide=True, **kwargs):
diff --git a/tfields/triangles3D.py b/tfields/triangles3D.py
index c9f22f5..bf24ae5 100644
--- a/tfields/triangles3D.py
+++ b/tfields/triangles3D.py
@@ -544,6 +544,10 @@ class Triangles3D(tfields.TensorFields):
                     there is just one True for each points index
                 triangle indices can have multiple true values
 
+                For Example, if you want to know the number of points in one
+                face, just do:
+                >> tris.in_triangles(poits).sum(axis=0)
+
         """
         if self.ntriangles() == 0:
             return np.empty((tensors.shape[0], 0), dtype=bool)
-- 
GitLab