Commit 87b011cd authored by dboe's avatar dboe
Browse files

ltensor plotting

parent 05f7bdde
Pipeline #107170 passed with stages
in 57 seconds
...@@ -39,7 +39,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check): ...@@ -39,7 +39,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
[1.0, 4.0, 8.0], [1.0, 4.0, 8.0],
] ]
) )
self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order) self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order)
def check_filled(self, tg): def check_filled(self, tg):
self.assertTrue(tg.equal(self.res)) self.assertTrue(tg.equal(self.res))
...@@ -48,7 +48,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check): ...@@ -48,7 +48,7 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
def check_empty(self, tg): def check_empty(self, tg):
self.assertTrue(tg.is_empty()) self.assertTrue(tg.is_empty())
def test_from_base_vectors(self): def test_grid(self):
self.check_filled(self._inst) self.check_filled(self._inst)
def test_empty(self): def test_empty(self):
...@@ -107,7 +107,7 @@ class TensorGrid_Test_Permutation1(TensorGrid_Test): ...@@ -107,7 +107,7 @@ class TensorGrid_Test_Permutation1(TensorGrid_Test):
[1.0, 4.0, 8.0], [1.0, 4.0, 8.0],
] ]
) )
self._inst = TensorGrid.from_base_vectors(*self.bv, iter_order=self.iter_order) self._inst = TensorGrid.grid(*self.bv, iter_order=self.iter_order)
class TensorGrid_Test_IO_Change(unittest.TestCase): class TensorGrid_Test_IO_Change(unittest.TestCase):
......
...@@ -33,6 +33,7 @@ import rna ...@@ -33,6 +33,7 @@ import rna
import tfields.bases import tfields.bases
np.seterr(all="warn", over="raise") np.seterr(all="warn", over="raise")
LOGGER = logging.getLogger(__name__)
def rank(tensor): def rank(tensor):
...@@ -1683,11 +1684,13 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods ...@@ -1683,11 +1684,13 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods
# return np.divide(self.T, self.norm(*args, **kwargs)).T # return np.divide(self.T, self.norm(*args, **kwargs)).T
return np.divide(self, self.norm(*args, **kwargs)[:, None]) return np.divide(self, self.norm(*args, **kwargs)[:, None])
def plot(self, **kwargs): def plot(self, *args, **kwargs):
""" """
Forwarding to rna.plotting.plot_tensor Forwarding to rna.plotting.plot_tensor
""" """
artist = rna.plotting.plot_tensor(self, **kwargs) # pylint: disable=no-member artist = rna.plotting.plot_tensor(
self, *args, **kwargs
) # pylint: disable=no-member
return artist return artist
...@@ -2023,15 +2026,28 @@ class TensorFields(Tensors): ...@@ -2023,15 +2026,28 @@ class TensorFields(Tensors):
weights = self.fields[weights] weights = self.fields[weights]
return super(TensorFields, self)._weights(weights, rigid=rigid) return super(TensorFields, self)._weights(weights, rigid=rigid)
def plot(self, **kwargs): def plot(self, *args, **kwargs):
""" """
Override Tensors plot method: Plotting the tensor field.
By default, vector fields are plotted with the quiver method
Args:
field_index: index of the field to plot (as quiver by default)
normalize: If true, normalize the field vectors to show only the direction
color: additional str argument 'norm' added. If color="norm", color with the norm.
""" """
field_index = kwargs.pop("field_index", None) field_index = kwargs.pop("field_index", None)
field_args = ["normalize"]
if field_index is None: if field_index is None:
artist = super(TensorFields, self).plot(**kwargs) for field_arg in field_args:
if field_arg in kwargs:
kwargs.pop(field_arg)
LOGGER.warning("Unused option %s", field_arg)
artist = super(TensorFields, self).plot(*args, **kwargs)
else: else:
normalize_field = kwargs.pop("normalize", False)
color = kwargs.get("color", None)
field = self.fields[field_index].copy() field = self.fields[field_index].copy()
if self.dim == field.dim: if self.dim == field.dim:
field.transform(self.coord_sys) field.transform(self.coord_sys)
...@@ -2040,10 +2056,17 @@ class TensorFields(Tensors): ...@@ -2040,10 +2056,17 @@ class TensorFields(Tensors):
"Careful: Plotting tensors with field of" "Careful: Plotting tensors with field of"
"different dimension. No coord_sys check performed." "different dimension. No coord_sys check performed."
) )
if color == "norm":
norm = field.norm()
kwargs["color"] = norm
if normalize_field:
field = field.normalized()
if field.dim <= 3: if field.dim <= 3:
artist = ( artist = (
rna.plotting.plot_tensor( # noqa: E501 pylint: disable=no-member rna.plotting.plot_tensor( # noqa: E501 pylint: disable=no-member
self, field, **kwargs self, *args, field, **kwargs
) )
) )
else: else:
...@@ -2830,9 +2853,9 @@ class TensorMaps(TensorFields): ...@@ -2830,9 +2853,9 @@ class TensorMaps(TensorFields):
paths = sorted_paths paths = sorted_paths
return paths return paths
def plot(self, **kwargs): # pragma: no cover def plot(self, *args, **kwargs): # pragma: no cover
""" """
Forwarding to rna.plotting.plot_mesh Plot the tensor map.
""" """
scalars_demanded = ( scalars_demanded = (
"color" not in kwargs "color" not in kwargs
...@@ -2845,8 +2868,8 @@ class TensorMaps(TensorFields): ...@@ -2845,8 +2868,8 @@ class TensorMaps(TensorFields):
if not len(map_) == 0: if not len(map_) == 0:
kwargs["color"] = map_.fields[map_index] kwargs["color"] = map_.fields[map_index]
if map_.dim == 3: if map_.dim == 3:
return rna.plotting.plot_mesh(self, map_, **kwargs) return rna.plotting.plot_mesh(self, *args, map_, **kwargs)
return rna.plotting.plot_tensor_map(self, map_, **kwargs) return rna.plotting.plot_tensor_map(self, *args, map_, **kwargs)
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover
......
...@@ -4,15 +4,26 @@ import functools ...@@ -4,15 +4,26 @@ import functools
import tfields.lib.util import tfields.lib.util
def ensure_complex(*base_vectors): def ensure_complex(*base_vectors) -> typing.List[typing.Tuple[float, float, complex]]:
# ensure, that the third entry in base_vector of type tuple becomes a complex type """
Ensure, that the third entry in base_vector of type tuple becomes a complex type.
The first two are mapped to float if they they are complex but with imag == 0.
"""
base_vectors = list(base_vectors) base_vectors = list(base_vectors)
for i, vector in enumerate(base_vectors): for i, vector in enumerate(base_vectors):
if isinstance(vector, tuple): if isinstance(vector, tuple):
if len(vector) == 3: if len(vector) == 3:
vector = list(vector) new_vector = []
vector[2] = complex(vector[2]) for j, entry in enumerate(vector):
base_vectors[i] = tuple(vector) # vector -> start, stop, n_steps
if isinstance(entry, complex) and entry.imag == 0:
# start and stop mapping to float
entry = entry.real
if j == 2:
# n_steps mapping to complex
entry = complex(entry)
new_vector.append(entry)
base_vectors[i] = tuple(new_vector)
return base_vectors return base_vectors
......
...@@ -217,6 +217,61 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None): ...@@ -217,6 +217,61 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
return i return i
def is_full_slice(index, shape):
"""
Determine if an index is the full slice (i.e. __getitem__ with this index returns the full
array) w.r.t the shape given.
Examples:
>>> import numpy as np
>>> import tfields
>>> class index_getter:
... def __getitem__(self, index):
... return index
>>> get_index = index_getter()
>>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
>>> shape = a.shape
>>> tfields.lib.util.is_full_slice(get_index[:], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[:, :], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[:, 1], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[1:, :], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:1, :], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, 1:], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, :1], shape)
False
>>> tfields.lib.util.is_full_slice(get_index[:, :-1], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, True])], shape)
True
>>> tfields.lib.util.is_full_slice(get_index[np.array([True, True, False])], shape)
False
"""
if isinstance(index, slice):
if (
index.step in (None, 1)
and index.start in (None, 0)
and index.stop in (None, -1, shape[0])
):
# full slice -> no type change
return True
else:
return False
elif isinstance(index, tuple):
return all((is_full_slice(ind, (shp,)) for ind, shp in zip(index, shape)))
elif isinstance(index, int):
return index == 0 and shape[0] == 1
elif isinstance(index, (np.ndarray, list)):
return all(index)
else:
raise NotImplementedError("Index Type %s", type(index))
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Implementaiton of TensorGrid class Implementaiton of TensorGrid class
""" """
import numpy as np import numpy as np
from .lib import grid from .lib import grid, util
from .core import TensorFields from .core import TensorFields
...@@ -31,9 +31,11 @@ class TensorGrid(TensorFields): ...@@ -31,9 +31,11 @@ class TensorGrid(TensorFields):
else: else:
default_base_vectors = kwargs.pop("base_vectors", None) default_base_vectors = kwargs.pop("base_vectors", None)
default_iter_order = np.arange(len(default_base_vectors)) default_iter_order = np.arange(len(default_base_vectors))
base_vectors = kwargs.pop("base_vectors", default_base_vectors) base_vectors = [
iter_order = kwargs.pop("iter_order", default_iter_order) tuple(bv) for bv in kwargs.pop("base_vectors", default_base_vectors)
]
base_vectors = grid.ensure_complex(*base_vectors) base_vectors = grid.ensure_complex(*base_vectors)
iter_order = kwargs.pop("iter_order", default_iter_order)
obj = super(TensorGrid, cls).__new__(cls, tensors, *fields, **kwargs) obj = super(TensorGrid, cls).__new__(cls, tensors, *fields, **kwargs)
obj.base_vectors = base_vectors obj.base_vectors = base_vectors
obj.iter_order = iter_order obj.iter_order = iter_order
...@@ -43,24 +45,19 @@ class TensorGrid(TensorFields): ...@@ -43,24 +45,19 @@ class TensorGrid(TensorFields):
if not self.is_empty(): if not self.is_empty():
return super().__getitem__(index) return super().__getitem__(index)
item = self.explicit() item = self.explicit()
if isinstance(index, slice): if not util.is_full_slice(index, item.shape):
if ( # downgrade to TensorFields
index.step in (None, 1) item = TensorFields(item)
and index.start in (None, 0)
and index.stop in (None, -1, len(item))
):
# full slice -> no type change
pass
else:
# downgrade
item = TensorFields(item)
return item.__getitem__(index) return item.__getitem__(index)
@classmethod @classmethod
def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs): def grid(cls, *base_vectors, tensors=None, fields=None, **kwargs):
""" """
Build the grid (explicitly) from base vectors Build the grid (explicitly) from base vectors
Args:
explicit args: see __new__
**kwargs: see TensorFields
""" """
iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors))) iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors)))
if tensors is None: if tensors is None:
...@@ -80,7 +77,7 @@ class TensorGrid(TensorFields): ...@@ -80,7 +77,7 @@ class TensorGrid(TensorFields):
bv_lengths = [int(bv[2].imag) for bv in base_vectors] bv_lengths = [int(bv[2].imag) for bv in base_vectors]
tensors = np.empty(shape=(np.prod(bv_lengths), 0)) tensors = np.empty(shape=(np.prod(bv_lengths), 0))
return cls.from_base_vectors(*base_vectors, tensors=tensors, **kwargs) return cls.grid(*base_vectors, tensors=tensors, **kwargs)
@classmethod @classmethod
def merged(cls, *objects, **kwargs): def merged(cls, *objects, **kwargs):
...@@ -117,7 +114,7 @@ class TensorGrid(TensorFields): ...@@ -117,7 +114,7 @@ class TensorGrid(TensorFields):
""" """
kwargs = {attr: getattr(self, attr) for attr in self.__slots__} kwargs = {attr: getattr(self, attr) for attr in self.__slots__}
base_vectors = kwargs.pop("base_vectors") base_vectors = kwargs.pop("base_vectors")
return self.from_base_vectors(*base_vectors, **kwargs) return self.grid(*base_vectors, **kwargs)
def change_iter_order(self, iter_order): def change_iter_order(self, iter_order):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment