From 03cc96cdee50a177389cfe9af9f2ce99c6c00d57 Mon Sep 17 00:00:00 2001 From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de> Date: Tue, 7 Aug 2018 09:57:58 +0200 Subject: [PATCH] before publishing. especially saving npz neat now --- tfields/__init__.py | 2 +- tfields/bases/bases.py | 4 ++-- tfields/core.py | 16 +++++++++++----- tfields/lib/util.py | 8 +++++--- tfields/plotting/__init__.py | 4 ++-- tfields/plotting/mpl.py | 23 +++++++++++++++-------- tfields/triangles3D.py | 3 +-- 7 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tfields/__init__.py b/tfields/__init__.py index 6565cde..4771b49 100644 --- a/tfields/__init__.py +++ b/tfields/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.0.dev5' +__version__ = '0.1.0.dev6' __author__ = "Daniel Boeckenhoff" __email__ = "daniel.boeckenhoff@ipp.mpg.de" diff --git a/tfields/bases/bases.py b/tfields/bases/bases.py index c94ef72..d27f4d9 100644 --- a/tfields/bases/bases.py +++ b/tfields/bases/bases.py @@ -42,13 +42,13 @@ def get_coord_system_name(base): str: name of base """ if isinstance(base, sympy.diffgeom.CoordSystem): - base = str(getattr(base, 'name')) + base = getattr(base, 'name') # if not (isinstance(base, string_types) or base is None): # baseType = type(base) # raise ValueError("Coordinate system must be string_type." # " Retrieved value '{base}' of type {baseType}." # .format(**locals())) - return base + return str(base) def lambdifiedTrafo(base_old, base_new): diff --git a/tfields/core.py b/tfields/core.py index 3931d67..f4a931d 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -277,9 +277,11 @@ class AbstractNdarray(np.ndarray): >>> _ = out_file_maps.seek(0) >>> m1 = tfields.TensorMaps.load(out_file_maps.name) >>> assert m.equal(m1) + >>> assert m.maps[0].dtype == m1.maps[0].dtype """ - np.savez(path, **self._as_dict()) + content_dict = self._as_dict() + np.savez(path, **content_dict) @classmethod def _load_npz(cls, path, **load_kwargs): @@ -454,7 +456,10 @@ class Tensors(AbstractNdarray): dtype = tensors.dtype else: if dtype is None: - dtype = np.float64 + if hasattr(tensors, 'dtype'): + dtype = tensors.dtype + else: + dtype = np.float64 ''' demand iterable structure ''' try: @@ -980,7 +985,7 @@ class Tensors(AbstractNdarray): """ coords = sympy.symbols('x y z') with self.tmp_transform(coord_sys or self.coord_sys): - mask = tfields.evalf(self, expression, coords=coords) + mask = tfields.evalf(np.array(self), expression, coords=coords) return mask def cut(self, expression, coord_sys=None): @@ -1087,7 +1092,7 @@ class Tensors(AbstractNdarray): return d[d > 0].reshape(d.shape[0], - 1).min(axis=1) except MemoryError: min_dists = np.empty(self.shape[0]) - for i, point in enumerate(other): + for i, point in enumerate(np.array(other)): d = self.distances([point], **kwargs) min_dists[i] = d[d > 0].reshape(-1).min() return min_dists @@ -1762,7 +1767,7 @@ class TensorMaps(TensorFields): indices = np.array(range(len(self))) keep_indices = indices[mask] if isinstance(keep_indices, int): - keep_indices = [keep_indices] + keep_indices = np.array([keep_indices]) delete_indices = set(indices.flat).difference(set(keep_indices.flat)) masks = [] @@ -1832,5 +1837,6 @@ class TensorMaps(TensorFields): if __name__ == '__main__': # pragma: no cover import doctest doctest.testmod() + # doctest.run_docstring_examples(Tensors._save_npz, globals()) # doctest.run_docstring_examples(TensorMaps.cut, globals()) # doctest.run_docstring_examples(AbstractNdarray._save_npz, globals()) diff --git a/tfields/lib/util.py b/tfields/lib/util.py index 66e3926..41a8ffc 100644 --- a/tfields/lib/util.py +++ b/tfields/lib/util.py @@ -57,10 +57,12 @@ def flatten(seq, container=None, keep_types=None): def multi_sort(array, *others, **kwargs): """ - Sort both lists with list 1 + Sort all given lists parralel with array sorting, ie rearrange the items in + the other lists in the same way, you rearrange them for array due to array + sorting Args: - array - *others + array (list) + *others (list) **kwargs: method (function): sorting function. Default is 'sorted' ...: further arguments are passed to method. Default rest is diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py index 23014fd..77d5d65 100644 --- a/tfields/plotting/__init__.py +++ b/tfields/plotting/__init__.py @@ -206,12 +206,12 @@ class PlotOptions(object): else: return self.pop(attr, default) - def retrieveChain(self, *args, **kwargs): + def retrieve_chain(self, *args, **kwargs): default = kwargs.pop('default', None) keep = kwargs.pop('keep', True) if len(args) > 1: return self.retrieve(args[0], - self.retrieveChain(*args[1:], + self.retrieve_chain(*args[1:], default=default, keep=keep), keep=keep) diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py index 1d39712..4cdb794 100644 --- a/tfields/plotting/mpl.py +++ b/tfields/plotting/mpl.py @@ -191,6 +191,8 @@ def plot_mesh(vertices, faces, **kwargs): vmin vmax """ + vertices = np.array(vertices) + faces = np.array(faces) if faces.shape[0] == 0: warnings.warn("No faces to plot") return None @@ -201,9 +203,9 @@ def plot_mesh(vertices, faces, **kwargs): full = True mesh = tfields.Mesh3D(vertices, faces=faces) xAxis, yAxis, zAxis = po.getXYZAxis() - facecolors = po.retrieveChain('facecolors', 'color', - default=0, - keep=False) + facecolors = po.retrieve_chain('facecolors', 'color', + default=0, + keep=False) if full: # implementation that will sort the triangles by zAxis centroids = mesh.centroids() @@ -212,8 +214,13 @@ def plot_mesh(vertices, faces, **kwargs): axesIndices.pop(axesIndices.index(yAxis)) zAxis = axesIndices[0] zs = centroids[:, zAxis] - zs, faces, facecolors = tfields.lib.util.multi_sort(zs, faces, - facecolors) + try: + iter(facecolors) + zs, faces, facecolors = tfields.lib.util.multi_sort(zs, faces, + facecolors) + except TypeError: + zs, faces = tfields.lib.util.multi_sort(zs, faces) + nFacesInitial = len(faces) else: # cut away "back sides" implementation @@ -246,9 +253,9 @@ def plot_mesh(vertices, faces, **kwargs): artist = plot_array(vertices, **d) elif po.dim == 3: label = po.pop('label', None) - color = po.retrieveChain('color', 'c', 'facecolors', - default='grey', - keep=False) + color = po.retrieve_chain('color', 'c', 'facecolors', + default='grey', + keep=False) color = po.formatColors(color, fmt='rgba', length=len(faces)) diff --git a/tfields/triangles3D.py b/tfields/triangles3D.py index 2134f91..54fbba7 100644 --- a/tfields/triangles3D.py +++ b/tfields/triangles3D.py @@ -488,7 +488,6 @@ class Triangles3D(tfields.TensorFields): not invertable matrices the you will always get False >>> m3 = tfields.Mesh3D([[0,0,0], [2,0,0], [4,0,0], [0,1,0]], ... faces=[[0, 1, 2], [0, 1, 3]]); - >>> import pytest >>> mask = m3.triangles()._in_triangles(np.array([0.2, 0.2, 0]), delta=0.3) >>> assert np.array_equal(mask, ... np.array([False, True], dtype=bool)) @@ -550,7 +549,7 @@ class Triangles3D(tfields.TensorFields): For Example, if you want to know the number of points in one face, just do: - >> tris.in_triangles(poits).sum(axis=0) + >> tris.in_triangles(poits).sum(axis=0)[face_index] """ if self.ntriangles() == 0: -- GitLab