Commit f17a7d97 authored by dboe's avatar dboe
Browse files

sorted test_core

parent 3c0ecff7
......@@ -7,35 +7,70 @@ import tfields
ATOL = 1e-8
class Base_Check(object):
class AbstractNdarray_Check(object):
def demand_equal(self, other):
raise NotImplementedError(self.__class__)
def test_pickle(self):
with NamedTemporaryFile(suffix='.pickle') as out_file:
pickle.dump(self._inst,
out_file)
out_file.flush()
out_file.seek(0)
reloaded = pickle.load(out_file)
self.demand_equal(reloaded)
def test_save_npz(self):
out_file = NamedTemporaryFile(suffix='.npz')
self._inst.save(out_file.name)
_ = out_file.seek(0) # this is only necessary in the test
load_inst = self._inst.__class__.load(out_file.name)
# allow_pickle=True) ?
self.demand_equal(load_inst)
def tearDown(self):
del self._inst
class Tensors_Check(AbstractNdarray_Check):
"""
Testing derivatives of Tensors
"""
_inst = None
def demand_equal(self, other, atol=False, transformed=False):
if atol:
self.assertTrue(self._inst.equal(other, atol=ATOL))
else:
self.assertTrue(self._inst.equal(other))
if not transformed:
self.assertEqual(self._inst.coord_sys, other.coord_sys)
self.assertEqual(self._inst.name, other.name)
def test_self_equality(self):
# Test equality
self.assertTrue(self._inst.equal(self._inst))
self.demand_equal(self._inst)
transformer = self._inst.copy()
transformer.transform(tfields.bases.CYLINDER)
self.demand_equal(transformer, atol=True, transformed=True)
def test_cylinderTrafo(self):
# Test coordinate transformations in circle
transformer = self._inst.copy()
transformer.transform(tfields.bases.CYLINDER)
self.assertTrue(tfields.Tensors(self._inst).equal(transformer, atol=ATOL))
self.assertTrue(self._inst.equal(transformer, atol=ATOL))
if len(self._inst) > 0:
self.assertFalse(np.array_equal(self._inst, transformer))
transformer.transform(tfields.bases.CARTESIAN)
self.assertTrue(self._inst.equal(transformer, atol=ATOL))
self.demand_equal(transformer, atol=True, transformed=True)
def test_spericalTrafo(self):
# Test coordinate transformations in circle
transformer = self._inst.copy()
transformer.transform(tfields.bases.SPHERICAL)
transformer.transform(tfields.bases.CARTESIAN)
self.assertTrue(self._inst.equal(transformer, atol=ATOL))
self.demand_equal(transformer, atol=True, transformed=True)
def test_basic_merge(self):
# create 3 copies with different coord_sys
......@@ -65,22 +100,8 @@ class Base_Check(object):
atol=ATOL)
self.assertTrue(value)
def test_pickle(self):
with NamedTemporaryFile(suffix='.pickle') as out_file:
pickle.dump(self._inst,
out_file)
out_file.flush()
out_file.seek(0)
reloaded = pickle.load(out_file)
self.assertTrue(self._inst.equal(reloaded))
def tearDown(self):
del self._inst
class Tensor_Fields_Check(object):
class TensorFields_Check(Tensors_Check):
def test_fields(self):
# field is of type list
self.assertTrue(isinstance(self._inst.fields, list))
......@@ -92,32 +113,34 @@ class Tensor_Fields_Check(object):
self.assertFalse(field is target_field)
class TensorMaps_Check(TensorFields_Check):
def test_maps(self):
self.assertIsNotNone(self._inst.maps)
"""
EMPTY TESTS
"""
class Tensors_Empty_Test(Base_Check, unittest.TestCase):
class Tensors_Empty_Test(Tensors_Check, unittest.TestCase):
def setUp(self):
self._inst = tfields.Tensors([], dim=3)
class TensorFields_Empty_Test(Tensors_Empty_Test, Tensor_Fields_Check):
class TensorFields_Empty_Test(TensorFields_Check, unittest.TestCase):
def setUp(self):
self._fields = []
self._inst = tfields.TensorFields([], dim=3)
class TensorMaps_Empty_Test(TensorFields_Empty_Test):
class TensorMaps_Empty_Test(TensorMaps_Check, unittest.TestCase):
def setUp(self):
self._fields = []
self._inst = tfields.TensorMaps([], dim=3)
self._maps = []
self._maps_fields = []
def test_maps(self):
self.assertIsNotNone(self._inst.maps)
class TensorFields_Copy_Test(TensorFields_Empty_Test):
def setUp(self):
......@@ -153,5 +176,38 @@ class TensorMaps_Copy_Test(TensorMaps_Empty_Test):
maps=self._maps)
class Container_Check(AbstractNdarray_Check):
def demand_equal(self, other):
raise NotImplementedError(self.__class__)
def test_item(self):
if len(self._inst.items) > 0:
self.assertEqual(len(self._inst), len(self._inst))
self.assertEqual(type(self._inst), type(self._inst))
class Container_Test(Container_Check, unittest.TestCase):
def setUp(self):
sphere = tfields.Mesh3D.grid(
(1, 1, 1),
(-np.pi, np.pi, 3),
(-np.pi / 2, np.pi / 2, 3),
coord_sys='spherical')
sphere2 = sphere.copy() * 3
self._inst = tfields.Container([sphere, sphere2])
class ContainerFolded_Test(Container_Check, unittest.TestCase):
def setUp(self):
sphere = tfields.Mesh3D.grid(
(1, 1, 1),
(-np.pi, np.pi, 3),
(-np.pi / 2, np.pi / 2, 3),
coord_sys='spherical')
sphere2 = sphere.copy() * 3
self._container = tfields.Container([sphere, sphere2])
self._inst = tfields.Container(self._container)
if __name__ == '__main__':
unittest.main()
......@@ -4,13 +4,13 @@ import unittest
import sympy # NOQA: F401
import os
import sys
from .test_core import Base_Check
from .test_core import Tensors_Check
THIS_DIR = os.path.dirname(
os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(THIS_DIR)))
class Sphere_Test(Base_Check, unittest.TestCase):
class Sphere_Test(Tensors_Check, unittest.TestCase):
def setUp(self):
self._inst = tfields.Mesh3D.grid(
(1, 1, 1),
......
......@@ -12,48 +12,53 @@ class Base_Check(object):
self.assertTrue(len(templates), len(self._instances))
for template, inst in zip(templates, self._instances):
merged_cut = merged.cut(template)
self.assertEqual(len(inst.maps), len(merged_cut.maps))
self.assertEqual(len(merged_cut.maps), len(template.maps))
for i, mp in enumerate(inst.maps):
self.assertEqual(len(mp),
len(merged_cut.maps[i]))
self.assertEqual(tfields.core.dim(mp),
tfields.core.dim(merged_cut.maps[i]))
self.assertEqual(tfields.core.dim(template.maps[i]),
tfields.core.dim(merged_cut.maps[i]))
self.assertTrue(tfields.TensorFields(inst).equal(
tfields.TensorFields(merged_cut)))
self.assertTrue(inst.equal(merged_cut))
# class Tensors_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._instances = [tfields.Tensors([], dim=3) for i in range(3)]
#
#
# class TensorFields_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._fields = []
# self._instances = [tfields.TensorFields([], dim=3) for i in range(3)]
# class TensorMaps_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._instances = [tfields.TensorMaps([], dim=3) for i in range(3)]
#class TensorFields_Test(TensorFields_Empty_Test):
# def setUp(self):
# base = [(-5, 5, 7)] * 3
# self._fields = [tfields.Tensors.grid(*base, coord_sys='cylinder'),
# tfields.Tensors(range(7**3))]
# tensors = tfields.Tensors.grid(*base)
# self._instances = [tfields.TensorFields(tensors, *self._fields)
# for i in range(3)]
# class TensorMaps_Test(TensorMaps_Empty_Test):
self._check_maps(inst, template, merged_cut)
def _check_maps(self, inst, template, merged_cut):
pass
class Tensors_Empty_Test(Base_Check, unittest.TestCase):
def setUp(self):
self._instances = [tfields.Tensors([], dim=3) for i in range(3)]
class TensorFields_Empty_Test(Base_Check, unittest.TestCase):
def setUp(self):
self._fields = []
self._instances = [tfields.TensorFields([], dim=3) for i in range(3)]
class TensorMaps_Empty_Test(Base_Check, unittest.TestCase):
def setUp(self):
self._instances = [tfields.TensorMaps([], dim=3) for i in range(3)]
def _check_maps(self, inst, template, merged_cut):
self.assertEqual(len(inst.maps), len(merged_cut.maps))
self.assertEqual(len(merged_cut.maps), len(template.maps))
for i, mp in enumerate(inst.maps):
self.assertEqual(len(mp),
len(merged_cut.maps[i]))
self.assertEqual(tfields.core.dim(mp),
tfields.core.dim(merged_cut.maps[i]))
self.assertEqual(tfields.core.dim(template.maps[i]),
tfields.core.dim(merged_cut.maps[i]))
self.assertTrue(tfields.TensorFields(inst).equal(
tfields.TensorFields(merged_cut)))
self.assertTrue(inst.equal(merged_cut))
class TensorFields_Test(TensorFields_Empty_Test):
def setUp(self):
base = [(-5, 5, 7)] * 3
self._fields = [tfields.Tensors.grid(*base, coord_sys='cylinder'),
tfields.Tensors(range(7**3))]
tensors = tfields.Tensors.grid(*base)
self._instances = [tfields.TensorFields(tensors, *self._fields)
for i in range(3)]
class TensorMaps_Test(TensorMaps_Empty_Test):
def setUp(self):
base = [(-1, 1, 3)] * 3
tensors = tfields.Tensors.grid(*base)
......
......@@ -220,7 +220,9 @@ class AbstractNdarray(np.ndarray):
@classmethod
@contextmanager
def _bypass_setters(cls, *slots, empty_means_all=True):
def _bypass_setters(cls, *slots,
empty_means_all=True,
demand_existence=False):
"""
Temporarily remove the setter in __slot_setters__ corresponding to slot
position in __slot__. You should know what you do, when using this.
......@@ -229,13 +231,22 @@ class AbstractNdarray(np.ndarray):
*slots (str): attribute names in __slots__
empty_means_all (bool): defines behaviour when slots is empty.
When True: if slots is empty mute all slots in __slots__
demand_existence (bool): if false do not check the existence of the
slot in __slots__ - do nothing for that slot. Handle with care!
"""
if not slots and empty_means_all:
slots = cls.__slots__
slot_indices = []
setters = []
for slot in slots:
slot_index = cls.__slots__.index(slot)
slot_index = cls.__slots__.index(slot)\
if slot in cls.__slots__ else None
if slot_index is None:
# slot not in cls.__slots__.
if demand_existence:
raise ValueError(
"Slot {slot} not existing".format(**locals()))
continue
if len(cls.__slot_setters__) < slot_index + 1:
# no setter to be found
continue
......@@ -466,7 +477,7 @@ class AbstractNdarray(np.ndarray):
bulk_type = getattr(tfields, bulk_type)
list_dict[key].append(bulk_type._from_dict(**sub_dict[index]))
with cls._bypass_setters('fields'):
with cls._bypass_setters('fields', demand_existence=False):
'''
Build the normal way
'''
......@@ -1722,11 +1733,12 @@ class TensorFields(Tensors):
index = index[0]
if item.fields:
# circumvent the setter here.
with self._bypass_setters('fields'):
with self._bypass_setters('fields',
demand_existence=False):
item.fields = [
field.__getitem__(index) for field in item.fields
]
except IndexError as err:
except IndexError as err: # noqa: F841
warnings.warn(
"Index error occured for field.__getitem__. Error "
"message: {err}".format(**locals())
......@@ -1953,10 +1965,9 @@ class Maps(Container):
A Maps object is a container for TensorFields sorted by dimension.
"""
def __new__(cls, maps, **kwargs):
if not issubclass(type(maps), Maps):
dims = [dim(obj) for obj in maps]
dims, maps = tfields.lib.util.multi_sort(dims, maps)
kwargs['labels'] = dims
if issubclass(type(maps), Maps):
kwargs['labels'] = maps.labels
maps = maps.items
maps_cp = []
for mp in maps:
......@@ -1969,8 +1980,16 @@ class Maps(Container):
maps = maps_cp
obj = super().__new__(cls, maps, **kwargs)
# obj._update()
return obj
def _update(self):
maps = self.items
dims = [dim(obj) for obj in maps]
dims, maps = tfields.lib.util.multi_sort(dims, maps)
self.items = maps
self.labels = dims
@property
def dims(self):
return self.labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment