diff --git a/tests/test_grid.py b/tests/test_grid.py index f8e08cd6312d87ab94f8e0df863a53149f400c30..88ef8bc1c7762884b4e7f3c8e20625f9069a843f 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -39,7 +39,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check): [1.0, 4.0, 8.0], ] ) - self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order) + self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order) def check_filled(self, tg): self.assertTrue(tg.equal(self.res)) @@ -48,7 +48,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check): def check_empty(self, tg): self.assertTrue(tg.is_empty()) - def test_from_base_vectors(self): + def test_grid(self): self.check_filled(self._inst) def test_empty(self): @@ -107,7 +107,7 @@ class TensorGrid_Test_Permutation1(TensorGrid_Test): [1.0, 4.0, 8.0], ] ) - self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order) + self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order) class TensorGrid_Test_IO_Change(unittest.TestCase): diff --git a/tfields/core.py b/tfields/core.py index 8c4ddc9ba796977b1c6190ac38cfd3f58c1e0294..2956fcc50e3dc1e6b856df0401341df0e05264b8 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -33,6 +33,7 @@ import rna import tfields.bases np.seterr(all="warn", over="raise") +LOGGER = logging.getLogger(__name__) def rank(tensor): @@ -1683,11 +1684,13 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods # return np.divide(self.T, self.norm(*args, **kwargs)).T return np.divide(self, self.norm(*args, **kwargs)[:, None]) - def plot(self, **kwargs): + def plot(self, *args, **kwargs): """ Forwarding to rna.plotting.plot_tensor """ - artist = rna.plotting.plot_tensor(self, **kwargs) # pylint: disable=no-member + artist = rna.plotting.plot_tensor( + self, *args, **kwargs + ) # pylint: disable=no-member return artist @@ -2023,15 +2026,28 @@ class TensorFields(Tensors): weights = self.fields[weights] return super(TensorFields, self)._weights(weights, rigid=rigid) - def plot(self, **kwargs): + def plot(self, *args, **kwargs): """ - Override Tensors plot method: - By default, vector fields are plotted with the quiver method + Plotting the tensor field. + + Args: + field_index: index of the field to plot (as quiver by default) + normalize: If true, normalize the field vectors to show only the direction + color: additional str argument 'norm' added. If color="norm", color with the norm. """ field_index = kwargs.pop("field_index", None) + + field_args = ["normalize"] if field_index is None: - artist = super(TensorFields, self).plot(**kwargs) + for field_arg in field_args: + if field_arg in kwargs: + kwargs.pop(field_arg) + LOGGER.warning("Unused option %s", field_arg) + artist = super(TensorFields, self).plot(*args, **kwargs) else: + normalize_field = kwargs.pop("normalize", False) + color = kwargs.get("color", None) + field = self.fields[field_index].copy() if self.dim == field.dim: field.transform(self.coord_sys) @@ -2040,10 +2056,17 @@ class TensorFields(Tensors): "Careful: Plotting tensors with field of" "different dimension. No coord_sys check performed." ) + + if color == "norm": + norm = field.norm() + kwargs["color"] = norm + if normalize_field: + field = field.normalized() + if field.dim <= 3: artist = ( rna.plotting.plot_tensor( # noqa: E501 pylint: disable=no-member - self, field, **kwargs + self, *args, field, **kwargs ) ) else: @@ -2830,9 +2853,9 @@ class TensorMaps(TensorFields): paths = sorted_paths return paths - def plot(self, **kwargs): # pragma: no cover + def plot(self, *args, **kwargs): # pragma: no cover """ - Forwarding to rna.plotting.plot_mesh + Plot the tensor map. """ scalars_demanded = ( "color" not in kwargs @@ -2845,8 +2868,8 @@ class TensorMaps(TensorFields): if not len(map_) == 0: kwargs["color"] = map_.fields[map_index] if map_.dim == 3: - return rna.plotting.plot_mesh(self, map_, **kwargs) - return rna.plotting.plot_tensor_map(self, map_, **kwargs) + return rna.plotting.plot_mesh(self, *args, map_, **kwargs) + return rna.plotting.plot_tensor_map(self, *args, map_, **kwargs) if __name__ == "__main__": # pragma: no cover diff --git a/tfields/lib/grid.py b/tfields/lib/grid.py index 2789fe58020cb51ac8c5ddb38b044df998765114..d9db0fa15a195ffe53b716cf75f2c95ecc90c9d4 100644 --- a/tfields/lib/grid.py +++ b/tfields/lib/grid.py @@ -4,15 +4,26 @@ import functools import tfields.lib.util -def ensure_complex(*base_vectors): - # ensure, that the third entry in base_vector of type tuple becomes a complex type +def ensure_complex(*base_vectors) -> typing.List[typing.Tuple[float, float, complex]]: + """ + Ensure, that the third entry in base_vector of type tuple becomes a complex type. + The first two are mapped to float if they they are complex but with imag == 0. + """ base_vectors = list(base_vectors) for i, vector in enumerate(base_vectors): if isinstance(vector, tuple): if len(vector) == 3: - vector = list(vector) - vector[2] = complex(vector[2]) - base_vectors[i] = tuple(vector) + new_vector = [] + for j, entry in enumerate(vector): + # vector -> start, stop, n_steps + if isinstance(entry, complex) and entry.imag == 0: + # start and stop mapping to float + entry = entry.real + if j == 2: + # n_steps mapping to complex + entry = complex(entry) + new_vector.append(entry) + base_vectors[i] = tuple(new_vector) return base_vectors diff --git a/tfields/lib/util.py b/tfields/lib/util.py index 05db5a7c2ca4506e4f0c26dcf1c8e714a915802d..78500436d646d6fbf4a60ea18e0b1548b8f4d974 100644 --- a/tfields/lib/util.py +++ b/tfields/lib/util.py @@ -217,6 +217,61 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None): return i +def is_full_slice(index, shape): + """ + Determine if an index is the full slice (i.e. __getitem__ with this index returns the full + array) w.r.t the shape given. + + Examples: + >>> import numpy as np + >>> import tfields + >>> class index_getter: + ... def __getitem__(self, index): + ... return index + >>> get_index = index_getter() + >>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]]) + >>> shape = a.shape + >>> tfields.lib.util.is_full_slice(get_index[:], shape) + True + >>> tfields.lib.util.is_full_slice(get_index[:, :], shape) + True + >>> tfields.lib.util.is_full_slice(get_index[:, 1], shape) + False + >>> tfields.lib.util.is_full_slice(get_index[1:, :], shape) + False + >>> tfields.lib.util.is_full_slice(get_index[:1, :], shape) + False + >>> tfields.lib.util.is_full_slice(get_index[:, 1:], shape) + False + >>> tfields.lib.util.is_full_slice(get_index[:, :1], shape) + False + >>> tfields.lib.util.is_full_slice(get_index[:, :-1], shape) + True + >>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, True])], shape) + True + >>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, False])], shape) + False + """ + if isinstance(index, slice): + if ( + index.step in (None, 1) + and index.start in (None, 0) + and index.stop in (None, -1, shape[0]) + ): + # full slice -> no type change + return True + else: + return False + elif isinstance(index, tuple): + return all((is_full_slice(ind, (shp,)) for ind, shp in zip(index, shape))) + elif isinstance(index, int): + return index == 0 and shape[0] == 1 + elif isinstance(index, (np.ndarray, list)): + return all(index) + else: + raise NotImplementedError("Index Type %s", type(index)) + + if __name__ == "__main__": import doctest diff --git a/tfields/tensor_grid.py b/tfields/tensor_grid.py index 770c501769bc90cb2d6f754d0511be98d7522bb4..7759c499429257e3a340507c116751da78f3c465 100644 --- a/tfields/tensor_grid.py +++ b/tfields/tensor_grid.py @@ -2,7 +2,7 @@ Implementaiton of TensorGrid class """ import numpy as np -from .lib import grid +from .lib import grid, util from .core import TensorFields @@ -31,9 +31,11 @@ class TensorGrid(TensorFields): else: default_base_vectors = kwargs.pop("base_vectors", None) default_iter_order = np.arange(len(default_base_vectors)) - base_vectors = kwargs.pop("base_vectors", default_base_vectors) - iter_order = kwargs.pop("iter_order", default_iter_order) + base_vectors = [ + tuple(bv) for bv in kwargs.pop("base_vectors", default_base_vectors) + ] base_vectors = grid.ensure_complex(*base_vectors) + iter_order = kwargs.pop("iter_order", default_iter_order) obj = super(TensorGrid, cls).__new__(cls, tensors, *fields, **kwargs) obj.base_vectors = base_vectors obj.iter_order = iter_order @@ -43,24 +45,19 @@ class TensorGrid(TensorFields): if not self.is_empty(): return super().__getitem__(index) item = self.explicit() - if isinstance(index, slice): - if ( - index.step in (None, 1) - and index.start in (None, 0) - and index.stop in (None, -1, len(item)) - ): - # full slice -> no type change - pass - else: - # downgrade - item = TensorFields(item) - + if not util.is_full_slice(index, item.shape): + # downgrade to TensorFields + item = TensorFields(item) return item.__getitem__(index) @classmethod - def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs): + def grid(cls, *base_vectors, tensors=None, fields=None, **kwargs): """ Build the grid (explicitly) from base vectors + + Args: + explicit args: see __new__ + **kwargs: see TensorFields """ iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors))) if tensors is None: @@ -80,7 +77,7 @@ class TensorGrid(TensorFields): bv_lengths = [int(bv[2].imag) for bv in base_vectors] tensors = np.empty(shape=(np.prod(bv_lengths), 0)) - return cls.from_base_vectors(*base_vectors, tensors=tensors, **kwargs) + return cls.grid(*base_vectors, tensors=tensors, **kwargs) @classmethod def merged(cls, *objects, **kwargs): @@ -117,7 +114,7 @@ class TensorGrid(TensorFields): """ kwargs = {attr: getattr(self, attr) for attr in self.__slots__} base_vectors = kwargs.pop("base_vectors") - return self.from_base_vectors(*base_vectors, **kwargs) + return self.grid(*base_vectors, **kwargs) def change_iter_order(self, iter_order): """