From 0e699865119ec35f89f682512f640678dd22f05d Mon Sep 17 00:00:00 2001 From: dboe <dboe@ipp.mpg.de> Date: Fri, 6 Aug 2021 13:53:44 +0200 Subject: [PATCH] indexing fields by name now possible --- tests/test_core.py | 57 +++++++++++++++++++++++++++++++++++++++++----- tfields/core.py | 51 +++++++++++++++++++++++++++++++++-------- 2 files changed, 92 insertions(+), 16 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index ff794da..b760d0d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,13 +1,16 @@ +# pylint:disable=missing_function_docstring import numpy as np from tempfile import NamedTemporaryFile import pickle import pathlib import unittest import tfields +import uuid ATOL = 1e-8 +# pylint:disable=no-member class Base_Check(object): def demand_equal(self, other): self.assertIsInstance(other, type(self._inst)) @@ -61,6 +64,7 @@ class AbstractNdarray_Check(Base_Check): pass +# pylint:disable=no-member class Tensors_Check(AbstractNdarray_Check): """ Testing derivatives of Tensors @@ -165,6 +169,7 @@ class Tensors_Check(AbstractNdarray_Check): self.assertTrue(value) +# pylint:disable=no-member class TensorFields_Check(Tensors_Check): def test_fields(self): self.assertIsNotNone(self._inst.fields) @@ -173,10 +178,13 @@ class TensorFields_Check(Tensors_Check): self.assertTrue(isinstance(self._inst.fields, list)) self.assertTrue(len(self._inst.fields) == len(self._fields)) - for field, target_field in zip(self._inst.fields, self._fields): - self.assertTrue(np.array_equal(field, target_field)) - # fields are copied not reffered by a pointer - self.assertFalse(field is target_field) + self.check_fields_equal(self._inst.fields, self._fields) + + def check_fields_equal(self, fields_a, fields_b): + for field, target_field in zip(self._inst.fields, self._fields): + self.assertTrue(np.array_equal(field, target_field)) + # fields are copied not reffered by a pointer + self.assertFalse(field is target_field) def demand_index_equal(self, index, check_type): super().demand_index_equal(index, check_type) @@ -199,6 +207,43 @@ class TensorFields_Check(Tensors_Check): for i in range(len(self._inst.fields)): self.assertIsNot(self._inst.fields[i], other.fields[i]) + def test_list_like_field(self): + if self._inst.fields: + fields = self._inst.fields + self._inst.fields = [] + for i, field in enumerate(fields): + self._inst.fields.append(field) + + # indexing + self.assertTrue(self._inst.fields[i].equal(field)) + + self.check_fields_equal(fields, self._inst.fields) + + def test_field_name_getitem(self): + if self._inst.fields: + for i, field in enumerate(self._inst.fields): + if field.name is not None: + self.assertTrue(self._inst.fields[field.name].equal(field)) + + def test_field_name_setitem(self): + if self._inst.fields: + fields = self._inst.fields + self._inst.fields = [] + for i, field in enumerate(fields): + if field.name is not None: + name = field.name + else: + name = str(uuid.uuid4()) + # setitem via fields + self._inst.fields[name] = field + + field_item = self._inst.fields[name] + self.assertTrue(field_item.equal(field)) + self.assertTrue(self._inst.fields[i].equal(field_item)) + self.assertEqual(field_item.name, name) + + self.check_fields_equal(fields, self._inst.fields) + class TensorMaps_Check(TensorFields_Check): def test_maps(self): @@ -245,7 +290,7 @@ class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase): self._maps_fields = [] -class TensorFields_Test(Tensors_Check, unittest.TestCase): +class TensorFields_Test(TensorFields_Check, unittest.TestCase): def setUp(self): base = [(-5, 5, 11)] * 3 self._fields = [ @@ -259,7 +304,7 @@ class TensorFields_Test(Tensors_Check, unittest.TestCase): self.assertTrue(self._fields[1].coord_sys, "cartesian") -class TensorMaps_Test(Tensors_Check, unittest.TestCase): +class TensorMaps_Test(TensorFields_Check, unittest.TestCase): def setUp(self): base = [(-1, 1, 3)] * 3 tensors = tfields.Tensors.grid(*base) diff --git a/tfields/core.py b/tfields/core.py index 2956fcc..cfc6810 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -1694,7 +1694,7 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods return artist -def as_tensors_list(tensors_list): +def as_fields(fields): """ Setter for TensorFields.fields Copies input @@ -1711,13 +1711,18 @@ def as_tensors_list(tensors_list): >>> assert mesh.maps[3].fields[0].equal([42, 21]) """ - if tensors_list is not None: - new_list = [] - for tensors in tensors_list: - tensors_list = Tensors(tensors) - new_list.append(tensors_list) - tensors_list = new_list - return tensors_list + if fields is None: + # None key passed e.g. by copy. We do not change keys here. + fields = [] + + if fields is not None: + new_fields = [] + for tensors in fields: + tensors = Tensors(tensors) + new_fields.append(tensors) + + new_fields = Fields(new_fields) + return new_fields def as_maps(maps): @@ -1807,7 +1812,7 @@ class TensorFields(Tensors): """ __slots__ = ["coord_sys", "name", "fields"] - __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_tensors_list] + __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_fields] def __new__(cls, tensors, *fields, **kwargs): rigid = kwargs.pop("rigid", True) @@ -2082,9 +2087,35 @@ class Fields(list, AbstractObject): them tensor fields """ + # TODO: maybe change type from list to dict laster on - no integer indexing def _args(self): return super()._args() + tuple(self) + def __setitem__(self, index, value): + if isinstance(index, str): + if not hasattr(value, "name"): + raise TypeError( + f"Value type {type(value)} does not support the 'name' field" + ) + value.name = index + + for i, field in enumerate(self): + if hasattr(field, "name") and field.name == index: + index = i + break + else: + self.append(value) + return + + super().__setitem__(index, value) + + def __getitem__(self, index): + if isinstance(index, str): + for i, field in enumerate(self): + if hasattr(field, "name") and field.name == index: + return field + return super().__getitem__(index) + class Container(Fields): """ @@ -2270,7 +2301,7 @@ class TensorMaps(TensorFields): __slot_setters__ = [ tfields.bases.get_coord_system_name, None, - as_tensors_list, + as_fields, as_maps, ] -- GitLab