Commit 0e699865 authored by dboe's avatar dboe
Browse files

indexing fields by name now possible

parent 87b011cd
Pipeline #107185 passed with stage
in 23 seconds
# pylint:disable=missing_function_docstring
import numpy as np import numpy as np
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import pickle import pickle
import pathlib import pathlib
import unittest import unittest
import tfields import tfields
import uuid
ATOL = 1e-8 ATOL = 1e-8
# pylint:disable=no-member
class Base_Check(object): class Base_Check(object):
def demand_equal(self, other): def demand_equal(self, other):
self.assertIsInstance(other, type(self._inst)) self.assertIsInstance(other, type(self._inst))
...@@ -61,6 +64,7 @@ class AbstractNdarray_Check(Base_Check): ...@@ -61,6 +64,7 @@ class AbstractNdarray_Check(Base_Check):
pass pass
# pylint:disable=no-member
class Tensors_Check(AbstractNdarray_Check): class Tensors_Check(AbstractNdarray_Check):
""" """
Testing derivatives of Tensors Testing derivatives of Tensors
...@@ -165,6 +169,7 @@ class Tensors_Check(AbstractNdarray_Check): ...@@ -165,6 +169,7 @@ class Tensors_Check(AbstractNdarray_Check):
self.assertTrue(value) self.assertTrue(value)
# pylint:disable=no-member
class TensorFields_Check(Tensors_Check): class TensorFields_Check(Tensors_Check):
def test_fields(self): def test_fields(self):
self.assertIsNotNone(self._inst.fields) self.assertIsNotNone(self._inst.fields)
...@@ -173,10 +178,13 @@ class TensorFields_Check(Tensors_Check): ...@@ -173,10 +178,13 @@ class TensorFields_Check(Tensors_Check):
self.assertTrue(isinstance(self._inst.fields, list)) self.assertTrue(isinstance(self._inst.fields, list))
self.assertTrue(len(self._inst.fields) == len(self._fields)) self.assertTrue(len(self._inst.fields) == len(self._fields))
for field, target_field in zip(self._inst.fields, self._fields): self.check_fields_equal(self._inst.fields, self._fields)
self.assertTrue(np.array_equal(field, target_field))
# fields are copied not reffered by a pointer def check_fields_equal(self, fields_a, fields_b):
self.assertFalse(field is target_field) for field, target_field in zip(self._inst.fields, self._fields):
self.assertTrue(np.array_equal(field, target_field))
# fields are copied not reffered by a pointer
self.assertFalse(field is target_field)
def demand_index_equal(self, index, check_type): def demand_index_equal(self, index, check_type):
super().demand_index_equal(index, check_type) super().demand_index_equal(index, check_type)
...@@ -199,6 +207,43 @@ class TensorFields_Check(Tensors_Check): ...@@ -199,6 +207,43 @@ class TensorFields_Check(Tensors_Check):
for i in range(len(self._inst.fields)): for i in range(len(self._inst.fields)):
self.assertIsNot(self._inst.fields[i], other.fields[i]) self.assertIsNot(self._inst.fields[i], other.fields[i])
def test_list_like_field(self):
if self._inst.fields:
fields = self._inst.fields
self._inst.fields = []
for i, field in enumerate(fields):
self._inst.fields.append(field)
# indexing
self.assertTrue(self._inst.fields[i].equal(field))
self.check_fields_equal(fields, self._inst.fields)
def test_field_name_getitem(self):
if self._inst.fields:
for i, field in enumerate(self._inst.fields):
if field.name is not None:
self.assertTrue(self._inst.fields[field.name].equal(field))
def test_field_name_setitem(self):
if self._inst.fields:
fields = self._inst.fields
self._inst.fields = []
for i, field in enumerate(fields):
if field.name is not None:
name = field.name
else:
name = str(uuid.uuid4())
# setitem via fields
self._inst.fields[name] = field
field_item = self._inst.fields[name]
self.assertTrue(field_item.equal(field))
self.assertTrue(self._inst.fields[i].equal(field_item))
self.assertEqual(field_item.name, name)
self.check_fields_equal(fields, self._inst.fields)
class TensorMaps_Check(TensorFields_Check): class TensorMaps_Check(TensorFields_Check):
def test_maps(self): def test_maps(self):
...@@ -245,7 +290,7 @@ class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase): ...@@ -245,7 +290,7 @@ class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase):
self._maps_fields = [] self._maps_fields = []
class TensorFields_Test(Tensors_Check, unittest.TestCase): class TensorFields_Test(TensorFields_Check, unittest.TestCase):
def setUp(self): def setUp(self):
base = [(-5, 5, 11)] * 3 base = [(-5, 5, 11)] * 3
self._fields = [ self._fields = [
...@@ -259,7 +304,7 @@ class TensorFields_Test(Tensors_Check, unittest.TestCase): ...@@ -259,7 +304,7 @@ class TensorFields_Test(Tensors_Check, unittest.TestCase):
self.assertTrue(self._fields[1].coord_sys, "cartesian") self.assertTrue(self._fields[1].coord_sys, "cartesian")
class TensorMaps_Test(Tensors_Check, unittest.TestCase): class TensorMaps_Test(TensorFields_Check, unittest.TestCase):
def setUp(self): def setUp(self):
base = [(-1, 1, 3)] * 3 base = [(-1, 1, 3)] * 3
tensors = tfields.Tensors.grid(*base) tensors = tfields.Tensors.grid(*base)
......
...@@ -1694,7 +1694,7 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods ...@@ -1694,7 +1694,7 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods
return artist return artist
def as_tensors_list(tensors_list): def as_fields(fields):
""" """
Setter for TensorFields.fields Setter for TensorFields.fields
Copies input Copies input
...@@ -1711,13 +1711,18 @@ def as_tensors_list(tensors_list): ...@@ -1711,13 +1711,18 @@ def as_tensors_list(tensors_list):
>>> assert mesh.maps[3].fields[0].equal([42, 21]) >>> assert mesh.maps[3].fields[0].equal([42, 21])
""" """
if tensors_list is not None: if fields is None:
new_list = [] # None key passed e.g. by copy. We do not change keys here.
for tensors in tensors_list: fields = []
tensors_list = Tensors(tensors)
new_list.append(tensors_list) if fields is not None:
tensors_list = new_list new_fields = []
return tensors_list for tensors in fields:
tensors = Tensors(tensors)
new_fields.append(tensors)
new_fields = Fields(new_fields)
return new_fields
def as_maps(maps): def as_maps(maps):
...@@ -1807,7 +1812,7 @@ class TensorFields(Tensors): ...@@ -1807,7 +1812,7 @@ class TensorFields(Tensors):
""" """
__slots__ = ["coord_sys", "name", "fields"] __slots__ = ["coord_sys", "name", "fields"]
__slot_setters__ = [tfields.bases.get_coord_system_name, None, as_tensors_list] __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_fields]
def __new__(cls, tensors, *fields, **kwargs): def __new__(cls, tensors, *fields, **kwargs):
rigid = kwargs.pop("rigid", True) rigid = kwargs.pop("rigid", True)
...@@ -2082,9 +2087,35 @@ class Fields(list, AbstractObject): ...@@ -2082,9 +2087,35 @@ class Fields(list, AbstractObject):
them tensor fields them tensor fields
""" """
# TODO: maybe change type from list to dict laster on - no integer indexing
def _args(self): def _args(self):
return super()._args() + tuple(self) return super()._args() + tuple(self)
def __setitem__(self, index, value):
if isinstance(index, str):
if not hasattr(value, "name"):
raise TypeError(
f"Value type {type(value)} does not support the 'name' field"
)
value.name = index
for i, field in enumerate(self):
if hasattr(field, "name") and field.name == index:
index = i
break
else:
self.append(value)
return
super().__setitem__(index, value)
def __getitem__(self, index):
if isinstance(index, str):
for i, field in enumerate(self):
if hasattr(field, "name") and field.name == index:
return field
return super().__getitem__(index)
class Container(Fields): class Container(Fields):
""" """
...@@ -2270,7 +2301,7 @@ class TensorMaps(TensorFields): ...@@ -2270,7 +2301,7 @@ class TensorMaps(TensorFields):
__slot_setters__ = [ __slot_setters__ = [
tfields.bases.get_coord_system_name, tfields.bases.get_coord_system_name,
None, None,
as_tensors_list, as_fields,
as_maps, as_maps,
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment