Commit 87b011cd authored by dboe's avatar dboe
Browse files

ltensor plotting

parent 05f7bdde
Pipeline #107170 passed with stages
in 57 seconds
......@@ -39,7 +39,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
[1.0, 4.0, 8.0],
]
)
self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order)
self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order)
def check_filled(self, tg):
self.assertTrue(tg.equal(self.res))
......@@ -48,7 +48,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
def check_empty(self, tg):
self.assertTrue(tg.is_empty())
def test_from_base_vectors(self):
def test_grid(self):
self.check_filled(self._inst)
def test_empty(self):
......@@ -107,7 +107,7 @@ class TensorGrid_Test_Permutation1(TensorGrid_Test):
[1.0, 4.0, 8.0],
]
)
self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order)
self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order)
class TensorGrid_Test_IO_Change(unittest.TestCase):
......
......@@ -33,6 +33,7 @@ import rna
import tfields.bases
np.seterr(all="warn", over="raise")
LOGGER = logging.getLogger(__name__)
def rank(tensor):
......@@ -1683,11 +1684,13 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods
# return np.divide(self.T, self.norm(*args, **kwargs)).T
return np.divide(self, self.norm(*args, **kwargs)[:, None])
def plot(self, **kwargs):
def plot(self, *args, **kwargs):
"""
Forwarding to rna.plotting.plot_tensor
"""
artist = rna.plotting.plot_tensor(self, **kwargs) # pylint: disable=no-member
artist = rna.plotting.plot_tensor(
self, *args, **kwargs
) # pylint: disable=no-member
return artist
......@@ -2023,15 +2026,28 @@ class TensorFields(Tensors):
weights = self.fields[weights]
return super(TensorFields, self)._weights(weights, rigid=rigid)
def plot(self, **kwargs):
def plot(self, *args, **kwargs):
"""
Override Tensors plot method:
By default, vector fields are plotted with the quiver method
Plotting the tensor field.
Args:
field_index: index of the field to plot (as quiver by default)
normalize: If true, normalize the field vectors to show only the direction
color: additional str argument 'norm' added. If color="norm", color with the norm.
"""
field_index = kwargs.pop("field_index", None)
field_args = ["normalize"]
if field_index is None:
artist = super(TensorFields, self).plot(**kwargs)
for field_arg in field_args:
if field_arg in kwargs:
kwargs.pop(field_arg)
LOGGER.warning("Unused option %s", field_arg)
artist = super(TensorFields, self).plot(*args, **kwargs)
else:
normalize_field = kwargs.pop("normalize", False)
color = kwargs.get("color", None)
field = self.fields[field_index].copy()
if self.dim == field.dim:
field.transform(self.coord_sys)
......@@ -2040,10 +2056,17 @@ class TensorFields(Tensors):
"Careful: Plotting tensors with field of"
"different dimension. No coord_sys check performed."
)
if color == "norm":
norm = field.norm()
kwargs["color"] = norm
if normalize_field:
field = field.normalized()
if field.dim <= 3:
artist = (
rna.plotting.plot_tensor( # noqa: E501 pylint: disable=no-member
self, field, **kwargs
self, *args, field, **kwargs
)
)
else:
......@@ -2830,9 +2853,9 @@ class TensorMaps(TensorFields):
paths = sorted_paths
return paths
def plot(self, **kwargs): # pragma: no cover
def plot(self, *args, **kwargs): # pragma: no cover
"""
Forwarding to rna.plotting.plot_mesh
Plot the tensor map.
"""
scalars_demanded = (
"color" not in kwargs
......@@ -2845,8 +2868,8 @@ class TensorMaps(TensorFields):
if not len(map_) == 0:
kwargs["color"] = map_.fields[map_index]
if map_.dim == 3:
return rna.plotting.plot_mesh(self, map_, **kwargs)
return rna.plotting.plot_tensor_map(self, map_, **kwargs)
return rna.plotting.plot_mesh(self, *args, map_, **kwargs)
return rna.plotting.plot_tensor_map(self, *args, map_, **kwargs)
if __name__ == "__main__": # pragma: no cover
......
......@@ -4,15 +4,26 @@ import functools
import tfields.lib.util
def ensure_complex(*base_vectors):
# ensure, that the third entry in base_vector of type tuple becomes a complex type
def ensure_complex(*base_vectors) -> typing.List[typing.Tuple[float, float, complex]]:
"""
Ensure, that the third entry in base_vector of type tuple becomes a complex type.
The first two are mapped to float if they they are complex but with imag == 0.
"""
base_vectors = list(base_vectors)
for i, vector in enumerate(base_vectors):
if isinstance(vector, tuple):
if len(vector) == 3:
vector = list(vector)
vector[2] = complex(vector[2])
base_vectors[i] = tuple(vector)
new_vector = []
for j, entry in enumerate(vector):
# vector -> start, stop, n_steps
if isinstance(entry, complex) and entry.imag == 0:
# start and stop mapping to float
entry = entry.real
if j == 2:
# n_steps mapping to complex
entry = complex(entry)
new_vector.append(entry)
base_vectors[i] = tuple(new_vector)
return base_vectors
......
......@@ -217,6 +217,61 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
return i
def is_full_slice(index, shape):
"""
Determine if an index is the full slice (i.e. __getitem__ with this index returns the full
array) w.r.t the shape given.
Examples:
>>> import numpy as np
>>> import tfields
>>> class index_getter:
... def __getitem__(self, index):
... return index
>>> get_index = index_getter()
>>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
>>> shape = a.shape
>>> tfields.lib.util.is_full_slice(get_index[:], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[:, :], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[:, 1], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[1:, :], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:1, :], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, 1:], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, :1], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, :-1], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, True])], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, False])], shape)
False
"""
if isinstance(index, slice):
if (
index.step in (None, 1)
and index.start in (None, 0)
and index.stop in (None, -1, shape[0])
):
# full slice -> no type change
return True
else:
return False
elif isinstance(index, tuple):
return all((is_full_slice(ind, (shp,)) for ind, shp in zip(index, shape)))
elif isinstance(index, int):
return index == 0 and shape[0] == 1
elif isinstance(index, (np.ndarray, list)):
return all(index)
else:
raise NotImplementedError("Index Type %s", type(index))
if __name__ == "__main__":
import doctest
......
......@@ -2,7 +2,7 @@
Implementaiton of TensorGrid class
"""
import numpy as np
from .lib import grid
from .lib import grid, util
from .core import TensorFields
......@@ -31,9 +31,11 @@ class TensorGrid(TensorFields):
else:
default_base_vectors = kwargs.pop("base_vectors", None)
default_iter_order = np.arange(len(default_base_vectors))
base_vectors = kwargs.pop("base_vectors", default_base_vectors)
iter_order = kwargs.pop("iter_order", default_iter_order)
base_vectors = [
tuple(bv) for bv in kwargs.pop("base_vectors", default_base_vectors)
]
base_vectors = grid.ensure_complex(*base_vectors)
iter_order = kwargs.pop("iter_order", default_iter_order)
obj = super(TensorGrid, cls).__new__(cls, tensors, *fields, **kwargs)
obj.base_vectors = base_vectors
obj.iter_order = iter_order
......@@ -43,24 +45,19 @@ class TensorGrid(TensorFields):
if not self.is_empty():
return super().__getitem__(index)
item = self.explicit()
if isinstance(index, slice):
if (
index.step in (None, 1)
and index.start in (None, 0)
and index.stop in (None, -1, len(item))
):
# full slice -> no type change
pass
else:
# downgrade
item = TensorFields(item)
if not util.is_full_slice(index, item.shape):
# downgrade to TensorFields
item = TensorFields(item)
return item.__getitem__(index)
@classmethod
def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs):
def grid(cls, *base_vectors, tensors=None, fields=None, **kwargs):
"""
Build the grid (explicitly) from base vectors
Args:
explicit args: see __new__
**kwargs: see TensorFields
"""
iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors)))
if tensors is None:
......@@ -80,7 +77,7 @@ class TensorGrid(TensorFields):
bv_lengths = [int(bv[2].imag) for bv in base_vectors]
tensors = np.empty(shape=(np.prod(bv_lengths), 0))
return cls.from_base_vectors(*base_vectors, tensors=tensors, **kwargs)
return cls.grid(*base_vectors, tensors=tensors, **kwargs)
@classmethod
def merged(cls, *objects, **kwargs):
......@@ -117,7 +114,7 @@ class TensorGrid(TensorFields):
"""
kwargs = {attr: getattr(self, attr) for attr in self.__slots__}
base_vectors = kwargs.pop("base_vectors")
return self.from_base_vectors(*base_vectors, **kwargs)
return self.grid(*base_vectors, **kwargs)
def change_iter_order(self, iter_order):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment