From 5327d1694e310973a2acd3d7911855b02d5ec71a Mon Sep 17 00:00:00 2001 From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de> Date: Fri, 18 Jan 2019 14:59:03 +0100 Subject: [PATCH] Container added --- tfields/__init__.py | 2 +- tfields/core.py | 61 +++++++++++++++++++++++++++++++----- tfields/lib/in_out.py | 19 +++++++++++ tfields/plotting/__init__.py | 44 +++++++++++++------------- tfields/plotting/mpl.py | 30 +++++++++--------- 5 files changed, 111 insertions(+), 45 deletions(-) diff --git a/tfields/__init__.py b/tfields/__init__.py index fb36f4b..e192083 100644 --- a/tfields/__init__.py +++ b/tfields/__init__.py @@ -7,7 +7,7 @@ from .lib import * from . import plotting # __all__ = ['core', 'points3D'] -from .core import Tensors, TensorFields, TensorMaps +from .core import Tensors, TensorFields, TensorMaps, Container from .points3D import Points3D from .mask import evalf diff --git a/tfields/core.py b/tfields/core.py index 3ea0f2f..f215b3e 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -6,6 +6,11 @@ Mail: daniel.boeckenhoff@ipp.mpg.de core of tfields library contains numpy ndarray derived bases of the tfields package + +Notes: + It could be worthwhile concidering np.li.mixins.NDArrayOperatorsMixin + ... see https://docs.scipy.org/doc/numpy-1.15.1/reference + /generated/numpy.lib.mixins.NDArrayOperatorsMixin.html """ import warnings import os @@ -164,6 +169,14 @@ class AbstractNdarray(np.ndarray): index = -(i + 1) setattr(self, slot, state[index]) + @property + def bulk(self): + """ + The pure ndarray version of the actual state + -> nothing attached + """ + return np.array(self) + def copy(self, *args, **kwargs): """ The standard ndarray copy does not copy slots. Correct for this. @@ -696,14 +709,6 @@ class Tensors(AbstractNdarray): **cls_kwargs) return inst - @property - def bulk(self): - """ - The pure ndarray version of the actual state - -> nothing attached - """ - return np.array(self) - @property def rank(self): """ @@ -1881,6 +1886,10 @@ class TensorMaps(TensorFields): used like: self.maps[map_pos_idx] map_indices_list (list of list of int): each int refers to index in a map. + + Returns: + list of type(self): One TensorMaps or TensorMaps subclass per + map_description """ # raise ValueError(map_descriptions) parts = [] @@ -1900,6 +1909,8 @@ class TensorMaps(TensorFields): def disjoint_map(self, mp_idx): """ Find the disjoint sets of map = self.maps[mp_idx] + As an example, this method is interesting for splitting a mesh + consisting of seperate parts Args: mp_idx (int): reference to map position used like: self.maps[mp_idx] @@ -1926,6 +1937,40 @@ class TensorMaps(TensorFields): return (0, maps_list) +class Container(AbstractNdarray): + """ + Story lists of tfields objects. Save mechanisms are provided + Examples: + >>> import numpy as np + >>> import tfields + >>> sphere = tfields.Mesh3D.grid( + ... (1, 1, 1), + ... (-np.pi, np.pi, 3), + ... (-np.pi / 2, np.pi / 2, 3), + ... coord_sys='spherical') + >>> sphere2 = sphere.copy() * 3 + >>> c = tfields.Container([sphere, sphere2]) + + # >>> c.save("~/tmp/spheres.npz") + # >>> c1 = tfields.Container.load("~/tmp/spheres.npz") + """ + __slots__ = ['items', 'labels'] + def __new__(cls, items, **kwargs): + kwargs['items'] = items + cls._update_slot_kwargs(kwargs) + + empty = np.empty(0, int) + obj = empty.view(cls) + + ''' set kwargs to slots attributes ''' + for attr in kwargs: + if attr not in cls._iter_slots(): + raise AttributeError("Keyword argument {attr} not accepted " + "for class {cls}".format(**locals())) + setattr(obj, attr, kwargs[attr]) + return obj + + if __name__ == '__main__': # pragma: no cover import doctest doctest.testmod() diff --git a/tfields/lib/in_out.py b/tfields/lib/in_out.py index 8809a9e..9d2e746 100644 --- a/tfields/lib/in_out.py +++ b/tfields/lib/in_out.py @@ -42,6 +42,25 @@ def resolve(path): return os.path.realpath(os.path.abspath(os.path.expanduser(path))) +def provide(cls, path, function, *args, **kwargs): + """ + provide the object of type tfields_clas saved unter path + and generated by function with *args and **kwargs + Args: + cls (type): + """ + path = resolve(path) + if os.path.exists(path): + return cls.load(path) + obj = function(*args, **kwargs) + if not isinstance(obj, cls): + return_type = type(obj) + raise TypeError("Return value of function is not {cls} but" + "{return_type}".format(**locals())) + obj.save(path) + return obj + + def cp(source, dest, overwrite=True): """ copy with shutil diff --git a/tfields/plotting/__init__.py b/tfields/plotting/__init__.py index cdd636e..8f14760 100644 --- a/tfields/plotting/__init__.py +++ b/tfields/plotting/__init__.py @@ -33,12 +33,12 @@ class PlotOptions(object): self.dim = kwargs.pop('dim', None) self.method = kwargs.pop('methodName', None) self.setXYZAxis(kwargs) - self.plotKwargs = kwargs + self.plot_kwargs = kwargs @property def method(self): """ - Method for plotting. Will be callable together with plotKwargs + Method for plotting. Will be callable together with plot_kwargs """ return self._method @@ -109,10 +109,10 @@ class PlotOptions(object): return if vmin is None: vmin = min(scalars) - self.plotKwargs['vmin'] = vmin + self.plot_kwargs['vmin'] = vmin if vmax is None: vmax = max(scalars) - self.plotKwargs['vmax'] = vmax + self.plot_kwargs['vmax'] = vmax def getNormArgs(self, vminDefault=0, vmaxDefault=1, cmapDefault=None): if cmapDefault is None: @@ -140,15 +140,16 @@ class PlotOptions(object): hasIter = False colors = [colors] - if hasattr(colors[0], '__iter__') and fmt == 'norm': - # rgba given but norm wanted - cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', - vminDefault=None, - vmaxDefault=None) - colors = to_scalars(colors, cmap, vmin, vmax) - self.plotKwargs['vmin'] = vmin - self.plotKwargs['vmax'] = vmax - self.plotKwargs['cmap'] = cmap + if fmt == 'norm': + if hasattr(colors[0], '__iter__'): + # rgba given but norm wanted + cmap, vmin, vmax = self.getNormArgs(cmapDefault='NotSpecified', + vminDefault=None, + vmaxDefault=None) + colors = to_scalars(colors, cmap, vmin, vmax) + self.plot_kwargs['vmin'] = vmin + self.plot_kwargs['vmax'] = vmax + self.plot_kwargs['cmap'] = cmap elif fmt == 'rgba': if isinstance(colors[0], string_types): # string color defined @@ -165,10 +166,11 @@ class PlotOptions(object): vmin=vmin, vmax=vmax, cmap=cmap) + print(colors) elif fmt == 'hex': colors = [mpl.colors.to_hex(color) for color in colors] else: - raise NotImplementedError("Color fmt {fmt} not implemented." + raise NotImplementedError("Color fmt '{fmt}' not implemented." .format(**locals())) if length is not None: @@ -185,9 +187,9 @@ class PlotOptions(object): return colors def delNormArgs(self): - self.plotKwargs.pop('vmin', None) - self.plotKwargs.pop('vmax', None) - self.plotKwargs.pop('cmap', None) + self.plot_kwargs.pop('vmin', None) + self.plot_kwargs.pop('vmax', None) + self.plot_kwargs.pop('cmap', None) def getSortedLabels(self, labels): """ @@ -196,16 +198,16 @@ class PlotOptions(object): return [labels[i] for i in self.getXYZAxis() if i is not None] def get(self, attr, default=None): - return self.plotKwargs.get(attr, default) + return self.plot_kwargs.get(attr, default) def pop(self, attr, default=None): - return self.plotKwargs.pop(attr, default) + return self.plot_kwargs.pop(attr, default) def set(self, attr, value): - self.plotKwargs[attr] = value + self.plot_kwargs[attr] = value def set_default(self, attr, value): - set_default(self.plotKwargs, attr, value) + set_default(self.plot_kwargs, attr, value) def retrieve(self, attr, default=None, keep=True): if keep: diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py index af4e96f..435eb3a 100644 --- a/tfields/plotting/mpl.py +++ b/tfields/plotting/mpl.py @@ -198,7 +198,7 @@ def plot_array(array, **kwargs): array[:, yAxis], array[:, zAxis]] artist = po.method(*args, - **po.plotKwargs) + **po.plot_kwargs) return artist @@ -258,8 +258,8 @@ def plot_mesh(vertices, faces, **kwargs): vertices = mesh - po.plotKwargs['methodName'] = 'tripcolor' - po.plotKwargs['triangles'] = faces + po.plot_kwargs['methodName'] = 'tripcolor' + po.plot_kwargs['triangles'] = faces """ sort out color arguments @@ -269,9 +269,9 @@ def plot_mesh(vertices, faces, **kwargs): length=nFacesInitial) if not full: facecolors = facecolors[dotProduct > 0] - po.plotKwargs['facecolors'] = facecolors + po.plot_kwargs['facecolors'] = facecolors - d = po.plotKwargs + d = po.plot_kwargs d['xAxis'] = xAxis d['yAxis'] = yAxis artist = plot_array(vertices, **d) @@ -294,7 +294,7 @@ def plot_mesh(vertices, faces, **kwargs): po.delNormArgs() triangles = np.array([vertices[face] for face in faces]) - artist = plt3D.art3d.Poly3DCollection(triangles, **po.plotKwargs) + artist = plt3D.art3d.Poly3DCollection(triangles, **po.plot_kwargs) po.axis.add_collection3d(artist) if edgecolor is not None: @@ -341,11 +341,11 @@ def plot_tensor_field(points, vectors, **kwargs): if po.dim == 3: artists.append(po.axis.quiver(point[xAxis], point[yAxis], point[zAxis], vector[xAxis], vector[yAxis], vector[zAxis], - **po.plotKwargs)) + **po.plot_kwargs)) elif po.dim == 2: artists.append(po.axis.quiver(point[xAxis], point[yAxis], vector[xAxis], vector[yAxis], - **po.plotKwargs)) + **po.plot_kwargs)) else: raise NotImplementedError("Dimension != 2|3") return artists @@ -406,7 +406,7 @@ def plot_plane(point, normal, **kwargs): kwargs['alpha'] = kwargs.pop('alpha', 0.5) po = tfields.plotting.PlotOptions(kwargs) - patch = Circle((0, 0), **po.plotKwargs) + patch = Circle((0, 0), **po.plot_kwargs) po.axis.add_patch(patch) pathpatch_2d_to_3d(patch, z=0, normal=normal) pathpatch_translate(patch, (point[0], point[1], point[2])) @@ -422,7 +422,7 @@ def plot_sphere(point, radius, **kwargs): z = point[2] + radius * np.outer(np.ones(np.size(u)), np.cos(v)) # Plot the surface - return po.axis.plot_surface(x, y, z, **po.plotKwargs) + return po.axis.plot_surface(x, y, z, **po.plot_kwargs) def plot_function(fun, **kwargs): @@ -443,7 +443,7 @@ def plot_function(fun, **kwargs): vals = np.linspace(xMin, xMax, n) args = (vals, map(fun, vals)) artist = po.axis.plot(*args, - **po.plotKwargs) + **po.plot_kwargs) return artist @@ -465,7 +465,7 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs): artists = [] # plot points - # artists.append(po.axis.plot(points, **po.plotKwargs)) + # artists.append(po.axis.plot(points, **po.plot_kwargs)) # plot errorbars for i in range(len(points)): @@ -474,19 +474,19 @@ def plot_errorbar(points, errors_up, errors_down=None, **kwargs): points[i, 0] - errors_down[i, 0]], [points[i, 1], points[i, 1]], [points[i, 2], points[i, 2]], - **po.plotKwargs)) + **po.plot_kwargs)) artists.append( po.axis.plot([points[i, 0], points[i, 0]], [points[i, 1] + errors_up[i, 1], points[i, 1] - errors_down[i, 1]], [points[i, 2], points[i, 2]], - **po.plotKwargs)) + **po.plot_kwargs)) artists.append( po.axis.plot([points[i, 0], points[i, 0]], [points[i, 1], points[i, 1]], [points[i, 2] + errors_up[i, 2], points[i, 2] - errors_down[i, 2]], - **po.plotKwargs)) + **po.plot_kwargs)) return artists -- GitLab