Commit 03cc96cd authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

before publishing. especially saving npz neat now

parent 31f2b698
__version__ = '0.1.0.dev5'
__version__ = '0.1.0.dev6'
__author__ = "Daniel Boeckenhoff"
__email__ = "daniel.boeckenhoff@ipp.mpg.de"
......
......@@ -42,13 +42,13 @@ def get_coord_system_name(base):
str: name of base
"""
if isinstance(base, sympy.diffgeom.CoordSystem):
base = str(getattr(base, 'name'))
base = getattr(base, 'name')
# if not (isinstance(base, string_types) or base is None):
# baseType = type(base)
# raise ValueError("Coordinate system must be string_type."
# " Retrieved value '{base}' of type {baseType}."
# .format(**locals()))
return base
return str(base)
def lambdifiedTrafo(base_old, base_new):
......
......@@ -277,9 +277,11 @@ class AbstractNdarray(np.ndarray):
>>> _ = out_file_maps.seek(0)
>>> m1 = tfields.TensorMaps.load(out_file_maps.name)
>>> assert m.equal(m1)
>>> assert m.maps[0].dtype == m1.maps[0].dtype
"""
np.savez(path, **self._as_dict())
content_dict = self._as_dict()
np.savez(path, **content_dict)
@classmethod
def _load_npz(cls, path, **load_kwargs):
......@@ -454,6 +456,9 @@ class Tensors(AbstractNdarray):
dtype = tensors.dtype
else:
if dtype is None:
if hasattr(tensors, 'dtype'):
dtype = tensors.dtype
else:
dtype = np.float64
''' demand iterable structure '''
......@@ -980,7 +985,7 @@ class Tensors(AbstractNdarray):
"""
coords = sympy.symbols('x y z')
with self.tmp_transform(coord_sys or self.coord_sys):
mask = tfields.evalf(self, expression, coords=coords)
mask = tfields.evalf(np.array(self), expression, coords=coords)
return mask
def cut(self, expression, coord_sys=None):
......@@ -1087,7 +1092,7 @@ class Tensors(AbstractNdarray):
return d[d > 0].reshape(d.shape[0], - 1).min(axis=1)
except MemoryError:
min_dists = np.empty(self.shape[0])
for i, point in enumerate(other):
for i, point in enumerate(np.array(other)):
d = self.distances([point], **kwargs)
min_dists[i] = d[d > 0].reshape(-1).min()
return min_dists
......@@ -1762,7 +1767,7 @@ class TensorMaps(TensorFields):
indices = np.array(range(len(self)))
keep_indices = indices[mask]
if isinstance(keep_indices, int):
keep_indices = [keep_indices]
keep_indices = np.array([keep_indices])
delete_indices = set(indices.flat).difference(set(keep_indices.flat))
masks = []
......@@ -1832,5 +1837,6 @@ class TensorMaps(TensorFields):
if __name__ == '__main__': # pragma: no cover
import doctest
doctest.testmod()
# doctest.run_docstring_examples(Tensors._save_npz, globals())
# doctest.run_docstring_examples(TensorMaps.cut, globals())
# doctest.run_docstring_examples(AbstractNdarray._save_npz, globals())
......@@ -57,10 +57,12 @@ def flatten(seq, container=None, keep_types=None):
def multi_sort(array, *others, **kwargs):
"""
Sort both lists with list 1
Sort all given lists parralel with array sorting, ie rearrange the items in
the other lists in the same way, you rearrange them for array due to array
sorting
Args:
array
*others
array (list)
*others (list)
**kwargs:
method (function): sorting function. Default is 'sorted'
...: further arguments are passed to method. Default rest is
......
......@@ -206,12 +206,12 @@ class PlotOptions(object):
else:
return self.pop(attr, default)
def retrieveChain(self, *args, **kwargs):
def retrieve_chain(self, *args, **kwargs):
default = kwargs.pop('default', None)
keep = kwargs.pop('keep', True)
if len(args) > 1:
return self.retrieve(args[0],
self.retrieveChain(*args[1:],
self.retrieve_chain(*args[1:],
default=default,
keep=keep),
keep=keep)
......
......@@ -191,6 +191,8 @@ def plot_mesh(vertices, faces, **kwargs):
vmin
vmax
"""
vertices = np.array(vertices)
faces = np.array(faces)
if faces.shape[0] == 0:
warnings.warn("No faces to plot")
return None
......@@ -201,7 +203,7 @@ def plot_mesh(vertices, faces, **kwargs):
full = True
mesh = tfields.Mesh3D(vertices, faces=faces)
xAxis, yAxis, zAxis = po.getXYZAxis()
facecolors = po.retrieveChain('facecolors', 'color',
facecolors = po.retrieve_chain('facecolors', 'color',
default=0,
keep=False)
if full:
......@@ -212,8 +214,13 @@ def plot_mesh(vertices, faces, **kwargs):
axesIndices.pop(axesIndices.index(yAxis))
zAxis = axesIndices[0]
zs = centroids[:, zAxis]
try:
iter(facecolors)
zs, faces, facecolors = tfields.lib.util.multi_sort(zs, faces,
facecolors)
except TypeError:
zs, faces = tfields.lib.util.multi_sort(zs, faces)
nFacesInitial = len(faces)
else:
# cut away "back sides" implementation
......@@ -246,7 +253,7 @@ def plot_mesh(vertices, faces, **kwargs):
artist = plot_array(vertices, **d)
elif po.dim == 3:
label = po.pop('label', None)
color = po.retrieveChain('color', 'c', 'facecolors',
color = po.retrieve_chain('color', 'c', 'facecolors',
default='grey',
keep=False)
color = po.formatColors(color,
......
......@@ -488,7 +488,6 @@ class Triangles3D(tfields.TensorFields):
not invertable matrices the you will always get False
>>> m3 = tfields.Mesh3D([[0,0,0], [2,0,0], [4,0,0], [0,1,0]],
... faces=[[0, 1, 2], [0, 1, 3]]);
>>> import pytest
>>> mask = m3.triangles()._in_triangles(np.array([0.2, 0.2, 0]), delta=0.3)
>>> assert np.array_equal(mask,
... np.array([False, True], dtype=bool))
......@@ -550,7 +549,7 @@ class Triangles3D(tfields.TensorFields):
For Example, if you want to know the number of points in one
face, just do:
>> tris.in_triangles(poits).sum(axis=0)
>> tris.in_triangles(poits).sum(axis=0)[face_index]
"""
if self.ntriangles() == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment