From 03cc96cdee50a177389cfe9af9f2ce99c6c00d57 Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Tue, 7 Aug 2018 09:57:58 +0200
Subject: [PATCH] before publishing. especially saving npz neat now

---
 tfields/__init__.py          |  2 +-
 tfields/bases/bases.py       |  4 ++--
 tfields/core.py              | 16 +++++++++++-----
 tfields/lib/util.py          |  8 +++++---
 tfields/plotting/__init__.py |  4 ++--
 tfields/plotting/mpl.py      | 23 +++++++++++++++--------
 tfields/triangles3D.py       |  3 +--
 7 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/tfields/__init__.py b/tfields/__init__.py
index 6565cde..4771b49 100644
--- a/tfields/__init__.py
+++ b/tfields/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.1.0.dev5'
+__version__ = '0.1.0.dev6'
 __author__ = "Daniel Boeckenhoff"
 __email__ = "daniel.boeckenhoff@ipp.mpg.de"
 
diff --git a/tfields/bases/bases.py b/tfields/bases/bases.py
index c94ef72..d27f4d9 100644
--- a/tfields/bases/bases.py
+++ b/tfields/bases/bases.py
@@ -42,13 +42,13 @@ def get_coord_system_name(base):
         str: name of base
     """
     if isinstance(base, sympy.diffgeom.CoordSystem):
-        base = str(getattr(base, 'name'))
+        base = getattr(base, 'name')
     # if not (isinstance(base, string_types) or base is None):
     #     baseType = type(base)
     #     raise ValueError("Coordinate system must be string_type."
     #                      " Retrieved value '{base}' of type {baseType}."
     #                      .format(**locals()))
-    return base
+    return str(base)
 
 
 def lambdifiedTrafo(base_old, base_new):
diff --git a/tfields/core.py b/tfields/core.py
index 3931d67..f4a931d 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -277,9 +277,11 @@ class AbstractNdarray(np.ndarray):
             >>> _ = out_file_maps.seek(0)
             >>> m1 = tfields.TensorMaps.load(out_file_maps.name)
             >>> assert m.equal(m1)
+            >>> assert m.maps[0].dtype == m1.maps[0].dtype
 
         """
-        np.savez(path, **self._as_dict())
+        content_dict = self._as_dict()
+        np.savez(path, **content_dict)
 
     @classmethod
     def _load_npz(cls, path, **load_kwargs):
@@ -454,7 +456,10 @@ class Tensors(AbstractNdarray):
                 dtype = tensors.dtype
         else:
             if dtype is None:
-                dtype = np.float64
+                if hasattr(tensors, 'dtype'):
+                    dtype = tensors.dtype
+                else:
+                    dtype = np.float64
 
         ''' demand iterable structure '''
         try:
@@ -980,7 +985,7 @@ class Tensors(AbstractNdarray):
         """
         coords = sympy.symbols('x y z')
         with self.tmp_transform(coord_sys or self.coord_sys):
-            mask = tfields.evalf(self, expression, coords=coords)
+            mask = tfields.evalf(np.array(self), expression, coords=coords)
         return mask
 
     def cut(self, expression, coord_sys=None):
@@ -1087,7 +1092,7 @@ class Tensors(AbstractNdarray):
             return d[d > 0].reshape(d.shape[0], - 1).min(axis=1)
         except MemoryError:
             min_dists = np.empty(self.shape[0])
-            for i, point in enumerate(other):
+            for i, point in enumerate(np.array(other)):
                 d = self.distances([point], **kwargs)
                 min_dists[i] = d[d > 0].reshape(-1).min()
             return min_dists
@@ -1762,7 +1767,7 @@ class TensorMaps(TensorFields):
         indices = np.array(range(len(self)))
         keep_indices = indices[mask]
         if isinstance(keep_indices, int):
-            keep_indices = [keep_indices]
+            keep_indices = np.array([keep_indices])
         delete_indices = set(indices.flat).difference(set(keep_indices.flat))
 
         masks = []
@@ -1832,5 +1837,6 @@ class TensorMaps(TensorFields):
 if __name__ == '__main__':  # pragma: no cover
     import doctest
     doctest.testmod()
+    # doctest.run_docstring_examples(Tensors._save_npz, globals())
     # doctest.run_docstring_examples(TensorMaps.cut, globals())
     # doctest.run_docstring_examples(AbstractNdarray._save_npz, globals())
diff --git a/tfields/lib/util.py b/tfields/lib/util.py
index 66e3926..41a8ffc 100644
--- a/tfields/lib/util.py
+++ b/tfields/lib/util.py
@@ -57,10 +57,12 @@ def flatten(seq, container=None, keep_types=None):
 
 def multi_sort(array, *others, **kwargs):
     """
-    Sort both lists with list 1
+    Sort all given lists parralel with array sorting, ie rearrange the items in
+    the other lists in the same way, you rearrange them for array due to array
+    sorting
     Args:
-        array
-        *others
+        array (list)
+        *others (list)
         **kwargs:
             method (function): sorting function. Default is 'sorted' 
             ...: further arguments are passed to method. Default rest is 
diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py
index 23014fd..77d5d65 100644
--- a/tfields/plotting/__init__.py
+++ b/tfields/plotting/__init__.py
@@ -206,12 +206,12 @@ class PlotOptions(object):
         else:
             return self.pop(attr, default)
 
-    def retrieveChain(self, *args, **kwargs):
+    def retrieve_chain(self, *args, **kwargs):
         default = kwargs.pop('default', None)
         keep = kwargs.pop('keep', True)
         if len(args) > 1:
             return self.retrieve(args[0],
-                                 self.retrieveChain(*args[1:],
+                                 self.retrieve_chain(*args[1:],
                                                     default=default,
                                                     keep=keep),
                                  keep=keep)
diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py
index 1d39712..4cdb794 100644
--- a/tfields/plotting/mpl.py
+++ b/tfields/plotting/mpl.py
@@ -191,6 +191,8 @@ def plot_mesh(vertices, faces, **kwargs):
         vmin
         vmax
     """
+    vertices = np.array(vertices)
+    faces = np.array(faces)
     if faces.shape[0] == 0:
         warnings.warn("No faces to plot")
         return None
@@ -201,9 +203,9 @@ def plot_mesh(vertices, faces, **kwargs):
         full = True
         mesh = tfields.Mesh3D(vertices, faces=faces)
         xAxis, yAxis, zAxis = po.getXYZAxis()
-        facecolors = po.retrieveChain('facecolors', 'color',
-                                      default=0,
-                                      keep=False)
+        facecolors = po.retrieve_chain('facecolors', 'color',
+                                       default=0,
+                                       keep=False)
         if full:
             # implementation that will sort the triangles by zAxis
             centroids = mesh.centroids()
@@ -212,8 +214,13 @@ def plot_mesh(vertices, faces, **kwargs):
             axesIndices.pop(axesIndices.index(yAxis))
             zAxis = axesIndices[0]
             zs = centroids[:, zAxis]
-            zs, faces, facecolors = tfields.lib.util.multi_sort(zs, faces,
-                                                                facecolors)
+            try:
+                iter(facecolors)
+                zs, faces, facecolors = tfields.lib.util.multi_sort(zs, faces,
+                                                                    facecolors)
+            except TypeError:
+                zs, faces = tfields.lib.util.multi_sort(zs, faces)
+            
             nFacesInitial = len(faces)
         else:
             # cut away "back sides" implementation
@@ -246,9 +253,9 @@ def plot_mesh(vertices, faces, **kwargs):
         artist = plot_array(vertices, **d)
     elif po.dim == 3:
         label = po.pop('label', None)
-        color = po.retrieveChain('color', 'c', 'facecolors',
-                                 default='grey',
-                                 keep=False)
+        color = po.retrieve_chain('color', 'c', 'facecolors',
+                                  default='grey',
+                                  keep=False)
         color = po.formatColors(color,
                                 fmt='rgba',
                                 length=len(faces))
diff --git a/tfields/triangles3D.py b/tfields/triangles3D.py
index 2134f91..54fbba7 100644
--- a/tfields/triangles3D.py
+++ b/tfields/triangles3D.py
@@ -488,7 +488,6 @@ class Triangles3D(tfields.TensorFields):
             not invertable matrices the you will always get False
             >>> m3 = tfields.Mesh3D([[0,0,0], [2,0,0], [4,0,0], [0,1,0]],
             ...                     faces=[[0, 1, 2], [0, 1, 3]]);
-            >>> import pytest
             >>> mask = m3.triangles()._in_triangles(np.array([0.2, 0.2, 0]), delta=0.3)
             >>> assert np.array_equal(mask,
             ...                       np.array([False,  True], dtype=bool))
@@ -550,7 +549,7 @@ class Triangles3D(tfields.TensorFields):
 
                 For Example, if you want to know the number of points in one
                 face, just do:
-                >> tris.in_triangles(poits).sum(axis=0)
+                >> tris.in_triangles(poits).sum(axis=0)[face_index]
 
         """
         if self.ntriangles() == 0:
-- 
GitLab