From 0e699865119ec35f89f682512f640678dd22f05d Mon Sep 17 00:00:00 2001
From: dboe <dboe@ipp.mpg.de>
Date: Fri, 6 Aug 2021 13:53:44 +0200
Subject: [PATCH] indexing fields by name now possible

---
 tests/test_core.py | 57 +++++++++++++++++++++++++++++++++++++++++-----
 tfields/core.py    | 51 +++++++++++++++++++++++++++++++++--------
 2 files changed, 92 insertions(+), 16 deletions(-)

diff --git a/tests/test_core.py b/tests/test_core.py
index ff794da..b760d0d 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,13 +1,16 @@
+# pylint:disable=missing_function_docstring
 import numpy as np
 from tempfile import NamedTemporaryFile
 import pickle
 import pathlib
 import unittest
 import tfields
+import uuid
 
 ATOL = 1e-8
 
 
+# pylint:disable=no-member
 class Base_Check(object):
     def demand_equal(self, other):
         self.assertIsInstance(other, type(self._inst))
@@ -61,6 +64,7 @@ class AbstractNdarray_Check(Base_Check):
     pass
 
 
+# pylint:disable=no-member
 class Tensors_Check(AbstractNdarray_Check):
     """
     Testing derivatives of Tensors
@@ -165,6 +169,7 @@ class Tensors_Check(AbstractNdarray_Check):
             self.assertTrue(value)
 
 
+# pylint:disable=no-member
 class TensorFields_Check(Tensors_Check):
     def test_fields(self):
         self.assertIsNotNone(self._inst.fields)
@@ -173,10 +178,13 @@ class TensorFields_Check(Tensors_Check):
             self.assertTrue(isinstance(self._inst.fields, list))
             self.assertTrue(len(self._inst.fields) == len(self._fields))
 
-            for field, target_field in zip(self._inst.fields, self._fields):
-                self.assertTrue(np.array_equal(field, target_field))
-                # fields are copied not reffered by a pointer
-                self.assertFalse(field is target_field)
+            self.check_fields_equal(self._inst.fields, self._fields)
+
+    def check_fields_equal(self, fields_a, fields_b):
+        for field, target_field in zip(self._inst.fields, self._fields):
+            self.assertTrue(np.array_equal(field, target_field))
+            # fields are copied not reffered by a pointer
+            self.assertFalse(field is target_field)
 
     def demand_index_equal(self, index, check_type):
         super().demand_index_equal(index, check_type)
@@ -199,6 +207,43 @@ class TensorFields_Check(Tensors_Check):
         for i in range(len(self._inst.fields)):
             self.assertIsNot(self._inst.fields[i], other.fields[i])
 
+    def test_list_like_field(self):
+        if self._inst.fields:
+            fields = self._inst.fields
+            self._inst.fields = []
+            for i, field in enumerate(fields):
+                self._inst.fields.append(field)
+
+                # indexing
+                self.assertTrue(self._inst.fields[i].equal(field))
+
+            self.check_fields_equal(fields, self._inst.fields)
+
+    def test_field_name_getitem(self):
+        if self._inst.fields:
+            for i, field in enumerate(self._inst.fields):
+                if field.name is not None:
+                    self.assertTrue(self._inst.fields[field.name].equal(field))
+
+    def test_field_name_setitem(self):
+        if self._inst.fields:
+            fields = self._inst.fields
+            self._inst.fields = []
+            for i, field in enumerate(fields):
+                if field.name is not None:
+                    name = field.name
+                else:
+                    name = str(uuid.uuid4())
+                # setitem via fields
+                self._inst.fields[name] = field
+
+                field_item = self._inst.fields[name]
+                self.assertTrue(field_item.equal(field))
+                self.assertTrue(self._inst.fields[i].equal(field_item))
+                self.assertEqual(field_item.name, name)
+
+            self.check_fields_equal(fields, self._inst.fields)
+
 
 class TensorMaps_Check(TensorFields_Check):
     def test_maps(self):
@@ -245,7 +290,7 @@ class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase):
         self._maps_fields = []
 
 
-class TensorFields_Test(Tensors_Check, unittest.TestCase):
+class TensorFields_Test(TensorFields_Check, unittest.TestCase):
     def setUp(self):
         base = [(-5, 5, 11)] * 3
         self._fields = [
@@ -259,7 +304,7 @@ class TensorFields_Test(Tensors_Check, unittest.TestCase):
         self.assertTrue(self._fields[1].coord_sys, "cartesian")
 
 
-class TensorMaps_Test(Tensors_Check, unittest.TestCase):
+class TensorMaps_Test(TensorFields_Check, unittest.TestCase):
     def setUp(self):
         base = [(-1, 1, 3)] * 3
         tensors = tfields.Tensors.grid(*base)
diff --git a/tfields/core.py b/tfields/core.py
index 2956fcc..cfc6810 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -1694,7 +1694,7 @@ class Tensors(AbstractNdarray):  # pylint: disable=too-many-public-methods
         return artist
 
 
-def as_tensors_list(tensors_list):
+def as_fields(fields):
     """
     Setter for TensorFields.fields
     Copies input
@@ -1711,13 +1711,18 @@ def as_tensors_list(tensors_list):
         >>> assert mesh.maps[3].fields[0].equal([42, 21])
 
     """
-    if tensors_list is not None:
-        new_list = []
-        for tensors in tensors_list:
-            tensors_list = Tensors(tensors)
-            new_list.append(tensors_list)
-        tensors_list = new_list
-    return tensors_list
+    if fields is None:
+        # None key passed e.g. by copy. We do not change keys here.
+        fields = []
+
+    if fields is not None:
+        new_fields = []
+        for tensors in fields:
+            tensors = Tensors(tensors)
+            new_fields.append(tensors)
+
+    new_fields = Fields(new_fields)
+    return new_fields
 
 
 def as_maps(maps):
@@ -1807,7 +1812,7 @@ class TensorFields(Tensors):
     """
 
     __slots__ = ["coord_sys", "name", "fields"]
-    __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_tensors_list]
+    __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_fields]
 
     def __new__(cls, tensors, *fields, **kwargs):
         rigid = kwargs.pop("rigid", True)
@@ -2082,9 +2087,35 @@ class Fields(list, AbstractObject):
     them tensor fields
     """
 
+    # TODO: maybe change type from list to dict laster on - no integer indexing
     def _args(self):
         return super()._args() + tuple(self)
 
+    def __setitem__(self, index, value):
+        if isinstance(index, str):
+            if not hasattr(value, "name"):
+                raise TypeError(
+                    f"Value type {type(value)} does not support the 'name' field"
+                )
+            value.name = index
+
+            for i, field in enumerate(self):
+                if hasattr(field, "name") and field.name == index:
+                    index = i
+                    break
+            else:
+                self.append(value)
+                return
+
+        super().__setitem__(index, value)
+
+    def __getitem__(self, index):
+        if isinstance(index, str):
+            for i, field in enumerate(self):
+                if hasattr(field, "name") and field.name == index:
+                    return field
+        return super().__getitem__(index)
+
 
 class Container(Fields):
     """
@@ -2270,7 +2301,7 @@ class TensorMaps(TensorFields):
     __slot_setters__ = [
         tfields.bases.get_coord_system_name,
         None,
-        as_tensors_list,
+        as_fields,
         as_maps,
     ]
 
-- 
GitLab