From 9bc2a23f709c19adc2d7c49638ff627a39bae4cc Mon Sep 17 00:00:00 2001 From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de> Date: Fri, 3 Aug 2018 18:43:37 +0200 Subject: [PATCH] np.savez implemented generally for all sub types of Tensors --- setup.py | 2 + tfields/__init__.py | 2 +- tfields/bases/bases.py | 5 +- tfields/core.py | 127 +++++++++++++++++++++++++++++++++------- tfields/mesh3D.py | 3 +- tfields/plotting/mpl.py | 9 ++- tfields/triangles3D.py | 4 ++ 7 files changed, 122 insertions(+), 30 deletions(-) diff --git a/setup.py b/setup.py index aa93c74..33fce85 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,8 @@ setup( 'sympy', 'scipy', 'pathlib', + 'six', + 'matplotlib', ], entry_points={ 'console_scripts': ['tfields = tfields.__main__:runDoctests'] diff --git a/tfields/__init__.py b/tfields/__init__.py index 9518a8e..6565cde 100644 --- a/tfields/__init__.py +++ b/tfields/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.0.dev4' +__version__ = '0.1.0.dev5' __author__ = "Daniel Boeckenhoff" __email__ = "daniel.boeckenhoff@ipp.mpg.de" diff --git a/tfields/bases/bases.py b/tfields/bases/bases.py index f66b21f..c94ef72 100644 --- a/tfields/bases/bases.py +++ b/tfields/bases/bases.py @@ -22,8 +22,9 @@ def get_coord_system(base): Return: sympy.diffgeom.get_coord_system """ - if isinstance(base, string_types): - base = getattr(tfields.bases, base) + if (isinstance(base, string_types) + or (isinstance(base, np.ndarray) and base.dtype.kind in {'U', 'S'})): + base = getattr(tfields.bases, str(base)) if not isinstance(base, sympy.diffgeom.CoordSystem): bse_tpe = type(base) expctd_tpe = type(sympy.diffgeom.CoordSystem) diff --git a/tfields/core.py b/tfields/core.py index 6075736..8c08e3d 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -115,6 +115,8 @@ class AbstractNdarray(np.ndarray): def __reduce__(self): """ important for pickling + see https://stackoverflow.com/questions/26598109/ + preserve-custom-attributes-when-pickling-subclass-of-numpy-array Examples: >>> from tempfile import NamedTemporaryFile >>> import pickle @@ -127,13 +129,13 @@ class AbstractNdarray(np.ndarray): >>> scalarField = TensorFields(vectors, scalars, coord_sys='cylinder') Save it and restore it - >>> outFile = NamedTemporaryFile(suffix='.pickle') + >>> out_file = NamedTemporaryFile(suffix='.pickle') >>> pickle.dump(scalarField, - ... outFile) - >>> _ = outFile.seek(0) + ... out_file) + >>> _ = out_file.seek(0) - >>> sf = pickle.load(outFile) + >>> sf = pickle.load(out_file) >>> sf.coord_sys == 'cylinder' True >>> sf.fields[0][2] == 2. @@ -250,26 +252,34 @@ class AbstractNdarray(np.ndarray): Args: path (open file or str/unicode): destination to save file to. Examples: + Build some dummies: >>> import tfields >>> from tempfile import NamedTemporaryFile - >>> outFile = NamedTemporaryFile(suffix='.npz') + >>> out_file = NamedTemporaryFile(suffix='.npz') >>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]]) - >>> p.save(outFile.name) - >>> _ = outFile.seek(0) - >>> p1 = tfields.Points3D.load(outFile.name) + + >>> scalars = tfields.Tensors([0, 1, 2]) + >>> vectors = tfields.Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0]]) + >>> maps = [tfields.TensorFields([[0, 1, 2], [0, 1, 2]], [42, 21]), + ... tfields.TensorFields([[1], [2]], [-42, -21])] + >>> m = tfields.TensorMaps(vectors, scalars, + ... maps=maps) + + Simply give the file name to save + >>> p.save(out_file.name) + >>> _ = out_file.seek(0) + >>> p1 = tfields.Points3D.load(out_file.name) >>> assert p.equal(p1) - """ - kwargs = {} - for attr in self._iter_slots(): - if not hasattr(self, attr): - # attribute in __slots__ not found. - warnings.warn("When saving instance of class {0} Attribute {1} not set." - "This Attribute is not saved.".format(self.__class__, attr), Warning) - else: - kwargs[attr] = getattr(self, attr) + The fully nested structure of a TensorMaps object is reconstructed + >>> out_file_maps = NamedTemporaryFile(suffix='.npz') + >>> m.save(out_file_maps.name) + >>> _ = out_file_maps.seek(0) + >>> m1 = tfields.TensorMaps.load(out_file_maps.name) + >>> assert m.equal(m1) - np.savez(path, self, **kwargs) + """ + np.savez(path, **self._as_dict()) @classmethod def _load_npz(cls, path, **load_kwargs): @@ -278,10 +288,82 @@ class AbstractNdarray(np.ndarray): Given a path to a npz file, construct the object """ np_file = np.load(path, **load_kwargs) - keys = np_file.keys() - bulk = np_file['arr_0'] - data_kwargs = {key: np_file[key] for key in keys if key not in ['arr_0']} - return cls.__new__(cls, bulk, **data_kwargs) + return cls._from_dict(**np_file) + + def _as_dict(self): + """ + Recursively walk trough all __slots__ and describe all elements + """ + d = {} + d['bulk'] = np.array(self) + d['bulk_type'] = self.__class__.__name__ + for attr in self._iter_slots(): + value = getattr(self, attr) + if isinstance(value, list): + if len(value) == 0: + d[attr] = None + if all([isinstance(part, AbstractNdarray) for part in value]): + for i, part in enumerate(value): + part_dict = part._as_dict() + for part_attr, part_value in part_dict.items(): + d["{attr}::{i}::{part_attr}".format(**locals())] = part_value + continue + if isinstance(value, AbstractNdarray): + value = value._as_dict() + d[attr] = value + return d + + @classmethod + def _from_dict(cls, **d): + """ + Opposite of _as_dict + """ + list_dict = {} + kwargs = {} + ''' + De-Flatten the first layer of lists + ''' + for key in sorted(d.keys()): + if '::' in key: + splits = key.split('::') + attr, _, end = key.partition('::') + if attr not in list_dict: + list_dict[attr] = {} + + index, _, end = end.partition('::') + if not index.isdigit(): + raise ValueError("None digit index given") + index = int(index) + if index not in list_dict[attr]: + list_dict[attr][index] = {} + list_dict[attr][index][end] = d[key] + else: + kwargs[key] = d[key] + + ''' + Build the lists (recursively) + ''' + for key in list_dict.keys(): + sub_dict = list_dict[key] + list_dict[key] = [] + for index in sorted(sub_dict.keys()): + bulk_type = sub_dict[index].get('bulk_type') + bulk_type = getattr(tfields, bulk_type.tolist()) + list_dict[key].append(bulk_type._from_dict(**sub_dict[index])) + + ''' + Build the normal way + ''' + bulk = kwargs.pop('bulk') + bulk_type = kwargs.pop('bulk_type') + obj = cls.__new__(cls, bulk, **kwargs) + + ''' + Set list attributes + ''' + for attr, list_value in list_dict.items(): + setattr(obj, attr, list_value) + return obj class Tensors(AbstractNdarray): @@ -1753,3 +1835,4 @@ if __name__ == '__main__': # pragma: no cover import doctest doctest.testmod() # doctest.run_docstring_examples(TensorFields.__getitem__, globals()) + # doctest.run_docstring_examples(AbstractNdarray._save_npz, globals()) diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py index 8d17498..b23c5b3 100644 --- a/tfields/mesh3D.py +++ b/tfields/mesh3D.py @@ -200,8 +200,7 @@ class Mesh3D(tfields.TensorMaps): >>> m1 = tfields.Mesh3D.load(outFile.name) >>> bool(np.all(m == m1)) True - >>> m1.faces - array([[0, 1, 2]]) + >>> assert np.array_equal(m1.faces, np.array([[0, 1, 2]])) """ def __new__(cls, tensors, *fields, **kwargs): diff --git a/tfields/plotting/mpl.py b/tfields/plotting/mpl.py index 3367fbf..1d39712 100644 --- a/tfields/plotting/mpl.py +++ b/tfields/plotting/mpl.py @@ -121,7 +121,10 @@ def save(path, *fmts, **kwargs): if first_label: if not extra_artists: extra_artists = [] - extra_artists.append(first_label) + if isinstance(first_label, list): + extra_artists.extend(first_label) + else: + extra_artists.append(first_label) kwargs['bbox_extra_artists'] = kwargs.pop('bbox_extra_artists', extra_artists) @@ -564,7 +567,7 @@ def autoscale_3d(axis, array=None, xLim=None, yLim=None, zLim=None): axis.set_zlim([zMin, zMax]) -def set_legend(axis, artists): +def set_legend(axis, artists, **kwargs): """ Convenience method to set a legend from multiple artists to an axis. """ @@ -574,7 +577,7 @@ def set_legend(axis, artists): handles.append(artist[0]) else: handles.append(artist) - axis.legend(handles=handles) + return axis.legend(handles=handles, **kwargs) def set_colorbar(axis, artist, label=None, divide=True, **kwargs): diff --git a/tfields/triangles3D.py b/tfields/triangles3D.py index c9f22f5..bf24ae5 100644 --- a/tfields/triangles3D.py +++ b/tfields/triangles3D.py @@ -544,6 +544,10 @@ class Triangles3D(tfields.TensorFields): there is just one True for each points index triangle indices can have multiple true values + For Example, if you want to know the number of points in one + face, just do: + >> tris.in_triangles(poits).sum(axis=0) + """ if self.ntriangles() == 0: return np.empty((tensors.shape[0], 0), dtype=bool) -- GitLab