Commit 0e699865 authored by dboe's avatar dboe
Browse files

indexing fields by name now possible

parent 87b011cd
Pipeline #107185 passed with stage
in 23 seconds
# pylint:disable=missing_function_docstring
import numpy as np
from tempfile import NamedTemporaryFile
import pickle
import pathlib
import unittest
import tfields
import uuid
ATOL = 1e-8
# pylint:disable=no-member
class Base_Check(object):
def demand_equal(self, other):
self.assertIsInstance(other, type(self._inst))
......@@ -61,6 +64,7 @@ class AbstractNdarray_Check(Base_Check):
pass
# pylint:disable=no-member
class Tensors_Check(AbstractNdarray_Check):
"""
Testing derivatives of Tensors
......@@ -165,6 +169,7 @@ class Tensors_Check(AbstractNdarray_Check):
self.assertTrue(value)
# pylint:disable=no-member
class TensorFields_Check(Tensors_Check):
def test_fields(self):
self.assertIsNotNone(self._inst.fields)
......@@ -173,10 +178,13 @@ class TensorFields_Check(Tensors_Check):
self.assertTrue(isinstance(self._inst.fields, list))
self.assertTrue(len(self._inst.fields) == len(self._fields))
for field, target_field in zip(self._inst.fields, self._fields):
self.assertTrue(np.array_equal(field, target_field))
# fields are copied not reffered by a pointer
self.assertFalse(field is target_field)
self.check_fields_equal(self._inst.fields, self._fields)
def check_fields_equal(self, fields_a, fields_b):
for field, target_field in zip(self._inst.fields, self._fields):
self.assertTrue(np.array_equal(field, target_field))
# fields are copied not reffered by a pointer
self.assertFalse(field is target_field)
def demand_index_equal(self, index, check_type):
super().demand_index_equal(index, check_type)
......@@ -199,6 +207,43 @@ class TensorFields_Check(Tensors_Check):
for i in range(len(self._inst.fields)):
self.assertIsNot(self._inst.fields[i], other.fields[i])
def test_list_like_field(self):
if self._inst.fields:
fields = self._inst.fields
self._inst.fields = []
for i, field in enumerate(fields):
self._inst.fields.append(field)
# indexing
self.assertTrue(self._inst.fields[i].equal(field))
self.check_fields_equal(fields, self._inst.fields)
def test_field_name_getitem(self):
if self._inst.fields:
for i, field in enumerate(self._inst.fields):
if field.name is not None:
self.assertTrue(self._inst.fields[field.name].equal(field))
def test_field_name_setitem(self):
if self._inst.fields:
fields = self._inst.fields
self._inst.fields = []
for i, field in enumerate(fields):
if field.name is not None:
name = field.name
else:
name = str(uuid.uuid4())
# setitem via fields
self._inst.fields[name] = field
field_item = self._inst.fields[name]
self.assertTrue(field_item.equal(field))
self.assertTrue(self._inst.fields[i].equal(field_item))
self.assertEqual(field_item.name, name)
self.check_fields_equal(fields, self._inst.fields)
class TensorMaps_Check(TensorFields_Check):
def test_maps(self):
......@@ -245,7 +290,7 @@ class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase):
self._maps_fields = []
class TensorFields_Test(Tensors_Check, unittest.TestCase):
class TensorFields_Test(TensorFields_Check, unittest.TestCase):
def setUp(self):
base = [(-5, 5, 11)] * 3
self._fields = [
......@@ -259,7 +304,7 @@ class TensorFields_Test(Tensors_Check, unittest.TestCase):
self.assertTrue(self._fields[1].coord_sys, "cartesian")
class TensorMaps_Test(Tensors_Check, unittest.TestCase):
class TensorMaps_Test(TensorFields_Check, unittest.TestCase):
def setUp(self):
base = [(-1, 1, 3)] * 3
tensors = tfields.Tensors.grid(*base)
......
......@@ -1694,7 +1694,7 @@ class Tensors(AbstractNdarray): # pylint: disable=too-many-public-methods
return artist
def as_tensors_list(tensors_list):
def as_fields(fields):
"""
Setter for TensorFields.fields
Copies input
......@@ -1711,13 +1711,18 @@ def as_tensors_list(tensors_list):
>>> assert mesh.maps[3].fields[0].equal([42, 21])
"""
if tensors_list is not None:
new_list = []
for tensors in tensors_list:
tensors_list = Tensors(tensors)
new_list.append(tensors_list)
tensors_list = new_list
return tensors_list
if fields is None:
# None key passed e.g. by copy. We do not change keys here.
fields = []
if fields is not None:
new_fields = []
for tensors in fields:
tensors = Tensors(tensors)
new_fields.append(tensors)
new_fields = Fields(new_fields)
return new_fields
def as_maps(maps):
......@@ -1807,7 +1812,7 @@ class TensorFields(Tensors):
"""
__slots__ = ["coord_sys", "name", "fields"]
__slot_setters__ = [tfields.bases.get_coord_system_name, None, as_tensors_list]
__slot_setters__ = [tfields.bases.get_coord_system_name, None, as_fields]
def __new__(cls, tensors, *fields, **kwargs):
rigid = kwargs.pop("rigid", True)
......@@ -2082,9 +2087,35 @@ class Fields(list, AbstractObject):
them tensor fields
"""
# TODO: maybe change type from list to dict laster on - no integer indexing
def _args(self):
return super()._args() + tuple(self)
def __setitem__(self, index, value):
if isinstance(index, str):
if not hasattr(value, "name"):
raise TypeError(
f"Value type {type(value)} does not support the 'name' field"
)
value.name = index
for i, field in enumerate(self):
if hasattr(field, "name") and field.name == index:
index = i
break
else:
self.append(value)
return
super().__setitem__(index, value)
def __getitem__(self, index):
if isinstance(index, str):
for i, field in enumerate(self):
if hasattr(field, "name") and field.name == index:
return field
return super().__getitem__(index)
class Container(Fields):
"""
......@@ -2270,7 +2301,7 @@ class TensorMaps(TensorFields):
__slot_setters__ = [
tfields.bases.get_coord_system_name,
None,
as_tensors_list,
as_fields,
as_maps,
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment