Commit 2e733190 authored by dboe's avatar dboe
Browse files

important core tests work

parent e87afa0f
...@@ -215,7 +215,17 @@ class Maps_Test(Base_Check, unittest.TestCase): ...@@ -215,7 +215,17 @@ class Maps_Test(Base_Check, unittest.TestCase):
class Container_Check(AbstractNdarray_Check): class Container_Check(AbstractNdarray_Check):
def demand_equal(self, other): def demand_equal(self, other):
raise NotImplementedError(self.__class__) for i, item in enumerate(self._inst.items):
if issubclass(type(item), tfields.core.AbstractNdarray):
self.assertTrue(other.items[i].equal(item))
else:
self.assertEqual(other.items[i], item)
try:
self._inst.labels[i]
except (IndexError, TypeError):
pass
else:
self.assertEqual(other.labels[i], self._inst.labels[i])
def test_item(self): def test_item(self):
if len(self._inst.items) > 0: if len(self._inst.items) > 0:
...@@ -231,7 +241,7 @@ class Container_Test(Container_Check, unittest.TestCase): ...@@ -231,7 +241,7 @@ class Container_Test(Container_Check, unittest.TestCase):
(-np.pi / 2, np.pi / 2, 3), (-np.pi / 2, np.pi / 2, 3),
coord_sys='spherical') coord_sys='spherical')
sphere2 = sphere.copy() * 3 sphere2 = sphere.copy() * 3
self._inst = tfields.Container([sphere, sphere2]) self._inst = tfields.Container([sphere, sphere2], labels=['test'])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -151,7 +151,7 @@ class AbstractObject(object): ...@@ -151,7 +151,7 @@ class AbstractObject(object):
[42] [42]
""" """
content_dict = self._as_dict() content_dict = self._as_new_dict()
np.savez(path, **content_dict) np.savez(path, **content_dict)
@classmethod @classmethod
...@@ -164,7 +164,7 @@ class AbstractObject(object): ...@@ -164,7 +164,7 @@ class AbstractObject(object):
# wheter we could avoid pickling (potential security issue) # wheter we could avoid pickling (potential security issue)
load_kwargs.setdefault('allow_pickle', True) load_kwargs.setdefault('allow_pickle', True)
np_file = np.load(path, **load_kwargs) np_file = np.load(path, **load_kwargs)
return cls._from_dict(**np_file) return cls._from_new_dict(dict(np_file))
def _args(self) -> tuple: def _args(self) -> tuple:
return tuple() return tuple()
...@@ -218,17 +218,18 @@ class AbstractObject(object): ...@@ -218,17 +218,18 @@ class AbstractObject(object):
""" """
for attr in here: for attr in here:
for key in here[attr]: for key in here[attr]:
if len(here[attr][key]) == 1: if 'type' in here[attr][key]:
attr_value = here[attr][key].pop('')
else:
obj_type = here[attr][key].get("type") obj_type = here[attr][key].get("type")
# obj_type = bulk_type.tolist() was necessary before. no clue if isinstance(obj_type, np.ndarray): # happens on np.load
obj_type = obj_type.tolist()
if isinstance(obj_type, bytes): if isinstance(obj_type, bytes):
# asthonishingly, this is not necessary under linux. # asthonishingly, this is not necessary under linux.
# Found under nt. ??? # Found under nt. ???
obj_type = obj_type.decode("UTF-8") obj_type = obj_type.decode("UTF-8")
obj_type = getattr(tfields, obj_type) obj_type = getattr(tfields, obj_type)
attr_value = obj_type._from_new_dict(here[attr][key]) attr_value = obj_type._from_new_dict(here[attr][key])
else: # if len(here[attr][key]) == 1:
attr_value = here[attr][key].pop('')
here[attr][key] = attr_value here[attr][key] = attr_value
''' '''
...@@ -583,65 +584,6 @@ class AbstractNdarray(np.ndarray, AbstractObject): ...@@ -583,65 +584,6 @@ class AbstractNdarray(np.ndarray, AbstractObject):
return inst return inst
@classmethod
def _from_dict(cls, **d):
"""
legacy method
Opposite of old _as_dict
"""
list_dict = {}
kwargs = {}
"""
De-Flatten the first layer of lists
"""
for key in sorted(list(d)):
if "::" in key:
splits = key.split("::")
attr, _, end = key.partition("::")
if attr not in list_dict:
list_dict[attr] = {}
index, _, end = end.partition("::")
if not index.isdigit():
raise ValueError("None digit index given")
index = int(index)
if index not in list_dict[attr]:
list_dict[attr][index] = {}
list_dict[attr][index][end] = d[key]
else:
kwargs[key] = d[key]
"""
Build the lists (recursively)
"""
for key in list(list_dict):
sub_dict = list_dict[key]
list_dict[key] = []
for index in sorted(list(sub_dict)):
bulk_type = sub_dict[index].get("bulk_type")
bulk_type = bulk_type.tolist()
if isinstance(bulk_type, bytes):
# asthonishingly, this is not necessary under linux.
# Found under nt. ???
bulk_type = bulk_type.decode("UTF-8")
bulk_type = getattr(tfields, bulk_type)
list_dict[key].append(bulk_type._from_dict(**sub_dict[index]))
with cls._bypass_setters('fields', demand_existence=False):
'''
Build the normal way
'''
bulk = kwargs.pop('bulk')
bulk_type = kwargs.pop('bulk_type')
obj = cls.__new__(cls, bulk, **kwargs)
'''
Set list attributes
'''
for attr, list_value in list_dict.items():
setattr(obj, attr, list_value)
return obj
class Tensors(AbstractNdarray): class Tensors(AbstractNdarray):
""" """
...@@ -769,9 +711,11 @@ class Tensors(AbstractNdarray): ...@@ -769,9 +711,11 @@ class Tensors(AbstractNdarray):
if issubclass(type(tensors), np.ndarray): if issubclass(type(tensors), np.ndarray):
# np.empty # np.empty
pass pass
elif hasattr(tensors, 'shape'):
dim = dim(tensors)
else: else:
raise ValueError( raise ValueError(
"Empty tensors need dimension " "parameter 'dim'." "Empty tensors need dimension parameter 'dim'."
) )
tensors = np.asarray(tensors, dtype=dtype, order=order) tensors = np.asarray(tensors, dtype=dtype, order=order)
...@@ -828,13 +772,13 @@ class Tensors(AbstractNdarray): ...@@ -828,13 +772,13 @@ class Tensors(AbstractNdarray):
def merged(cls, *objects, **kwargs): def merged(cls, *objects, **kwargs):
""" """
Factory method Factory method
Merges all tensor inputs to one tensor Merges all input arguments to one object
Args: Args:
**kwargs: passed to cls
dim (int):
return_templates (bool): return the templates which can be used return_templates (bool): return the templates which can be used
together with cut to retrieve the original objects together with cut to retrieve the original objects
dim (int):
**kwargs: passed to cls
Examples: Examples:
>>> import numpy as np >>> import numpy as np
...@@ -928,13 +872,10 @@ class Tensors(AbstractNdarray): ...@@ -928,13 +872,10 @@ class Tensors(AbstractNdarray):
for i, obj in enumerate(remainingObjects): for i, obj in enumerate(remainingObjects):
tensors = np.append(tensors, obj, axis=0) tensors = np.append(tensors, obj, axis=0)
if len(tensors) == 0 and 'dim' not in kwargs: if len(tensors) == 0 and not kwargs.get('dim', None):
# if you can not determine the tensor dimension, search for the # if you can not determine the tensor dimension, search for the
# first object with some entries # first object with some entries
for obj in objects: kwargs['dim'] = dim(objects[0])
if len(obj) != 0:
kwargs['dim'] = dim(obj)
break
if not return_templates: if not return_templates:
return cls.__new__(cls, tensors, **kwargs) return cls.__new__(cls, tensors, **kwargs)
...@@ -2089,7 +2030,7 @@ class Container(AbstractNdarray): ...@@ -2089,7 +2030,7 @@ class Container(AbstractNdarray):
if issubclass(type(items), Container): if issubclass(type(items), Container):
kwargs.setdefault('labels', items.labels) kwargs.setdefault('labels', items.labels)
items = items.items items = items.items
kwargs["items"] = items kwargs["items"] = kwargs.pop('items_hack', items)
cls._update_slot_kwargs(kwargs) cls._update_slot_kwargs(kwargs)
empty = np.empty(0, int) empty = np.empty(0, int)
...@@ -2105,6 +2046,11 @@ class Container(AbstractNdarray): ...@@ -2105,6 +2046,11 @@ class Container(AbstractNdarray):
setattr(obj, attr, kwargs[attr]) setattr(obj, attr, kwargs[attr])
return obj return obj
def _kwargs(self):
d = super()._kwargs()
d['items_hack'] = d.pop('items') # hack for backwards compatibility
return d
def __getitem__(self, index): def __getitem__(self, index):
return self.items[index] return self.items[index]
...@@ -2138,9 +2084,17 @@ class Maps(sortedcontainers.SortedDict, AbstractObject): ...@@ -2138,9 +2084,17 @@ class Maps(sortedcontainers.SortedDict, AbstractObject):
sortedcontainers.SortedItemsView): sortedcontainers.SortedItemsView):
args = tuple([v for k, v in args[0]]) args = tuple([v for k, v in args[0]])
elif len(args) == 1 and isinstance(args[0], list): elif len(args) == 1 and isinstance(args[0], list):
if len(args[0]) > 0 and not issubclass(type(args[0][0]), tuple): if args[0]:
# Maps([mp, mp2])- not Maps([]) and not Maps([(key, value)]) # not not Maps([])
args = tuple(args[0]) if issubclass(type(args[0][0]), tuple):
# Maps([(key, value), (key, value), ...])
args = tuple(v for k,v in args[0])
else:
# Maps([mp, mp, ...])
args = tuple(args[0])
else:
# Maps([]) -> Maps()
args = tuple()
elif len(args) == 1 and issubclass(type(args[0]), dict): elif len(args) == 1 and issubclass(type(args[0]), dict):
# Maps([]) - includes Maps i.e. copy # Maps([]) - includes Maps i.e. copy
# dangerous because we run beefore super init # dangerous because we run beefore super init
...@@ -2380,7 +2334,7 @@ class TensorMaps(TensorFields): ...@@ -2380,7 +2334,7 @@ class TensorMaps(TensorFields):
cum_tensor_lengths = [sum(tensor_lengths[:i]) cum_tensor_lengths = [sum(tensor_lengths[:i])
for i in range(len(objects))] for i in range(len(objects))]
return_value = super(TensorMaps, cls).merged(*objects, **kwargs) return_value = super().merged(*objects, **kwargs)
return_templates = kwargs.get('return_templates', False) return_templates = kwargs.get('return_templates', False)
if return_templates: if return_templates:
inst, templates = return_value inst, templates = return_value
...@@ -2392,15 +2346,15 @@ class TensorMaps(TensorFields): ...@@ -2392,15 +2346,15 @@ class TensorMaps(TensorFields):
for dimension, mp in obj.maps.items(): for dimension, mp in obj.maps.items():
mp = mp + cum_tensor_lengths[i] mp = mp + cum_tensor_lengths[i]
if dimension not in dim_maps_dict: if dimension not in dim_maps_dict:
dim_maps_dict[dimension] = [] dim_maps_dict[dimension] = {}
dim_maps_dict[dimension].append(mp) dim_maps_dict[dimension][i] = mp
maps = [] maps = []
template_maps_list = [[] for i in range(len(objects))] template_maps_list = [[] for i in range(len(objects))]
for dimension in sorted(dim_maps_dict): for dimension in sorted(dim_maps_dict):
# sort by object index # sort by object index
obj_indices = sorted(dim_maps_dict[dimension].keys()) dim_maps = [dim_maps_dict[dimension][i]
dim_maps = [dim_maps_dict[dimension][i] for i in obj_indices] for i in range(len(objects))]
return_value = TensorFields.merged( return_value = TensorFields.merged(
*dim_maps, *dim_maps,
...@@ -2408,14 +2362,14 @@ class TensorMaps(TensorFields): ...@@ -2408,14 +2362,14 @@ class TensorMaps(TensorFields):
) )
if return_templates: if return_templates:
mp, dimension_map_templates = return_value mp, dimension_map_templates = return_value
for i in obj_indices: for i in range(len(objects)):
template_maps_list[i].append(dimension_map_templates[i]) template_maps_list[i].append(dimension_map_templates[i])
else: else:
mp = return_value mp = return_value
maps.append(mp) maps.append(mp)
inst = cls.__new__(cls, inst, maps=maps) inst.maps = maps
if 'return_templates' in kwargs: if return_templates:
for i, template_maps in enumerate(template_maps_list): for i, template_maps in enumerate(template_maps_list):
templates[i] = tfields.TensorMaps( templates[i] = tfields.TensorMaps(
templates[i], templates[i],
......
...@@ -233,7 +233,7 @@ class Mesh3D(tfields.TensorMaps): ...@@ -233,7 +233,7 @@ class Mesh3D(tfields.TensorMaps):
obj = super(Mesh3D, cls).__new__(cls, tensors, *fields, **kwargs) obj = super(Mesh3D, cls).__new__(cls, tensors, *fields, **kwargs)
if len(obj.maps) > 1: if len(obj.maps) > 1:
raise ValueError("Mesh3D only allows one map") raise ValueError("Mesh3D only allows one map")
if obj.maps and obj.maps[0].dim != 3: if obj.maps and (len(obj.maps) > 1 or obj.maps.keys()[0] != 3):
raise ValueError("Face dimension should be 3") raise ValueError("Face dimension should be 3")
return obj return obj
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment