diff --git a/test/test_core.py b/test/test_core.py index b0c89e95e2fcd2248118c99e711c8e465c09f9ca..5d17c83472660092817c8f848f7fe7e5550caf97 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -215,7 +215,17 @@ class Maps_Test(Base_Check, unittest.TestCase): class Container_Check(AbstractNdarray_Check): def demand_equal(self, other): - raise NotImplementedError(self.__class__) + for i, item in enumerate(self._inst.items): + if issubclass(type(item), tfields.core.AbstractNdarray): + self.assertTrue(other.items[i].equal(item)) + else: + self.assertEqual(other.items[i], item) + try: + self._inst.labels[i] + except (IndexError, TypeError): + pass + else: + self.assertEqual(other.labels[i], self._inst.labels[i]) def test_item(self): if len(self._inst.items) > 0: @@ -231,7 +241,7 @@ class Container_Test(Container_Check, unittest.TestCase): (-np.pi / 2, np.pi / 2, 3), coord_sys='spherical') sphere2 = sphere.copy() * 3 - self._inst = tfields.Container([sphere, sphere2]) + self._inst = tfields.Container([sphere, sphere2], labels=['test']) if __name__ == '__main__': diff --git a/tfields/core.py b/tfields/core.py index da0f815c7640b4b248af55a9ba4ebf09ae3373d9..88f9107b9d91c1555eef8aa4e22aa5d176157c17 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -151,7 +151,7 @@ class AbstractObject(object): [42] """ - content_dict = self._as_dict() + content_dict = self._as_new_dict() np.savez(path, **content_dict) @classmethod @@ -164,7 +164,7 @@ class AbstractObject(object): # wheter we could avoid pickling (potential security issue) load_kwargs.setdefault('allow_pickle', True) np_file = np.load(path, **load_kwargs) - return cls._from_dict(**np_file) + return cls._from_new_dict(dict(np_file)) def _args(self) -> tuple: return tuple() @@ -218,17 +218,18 @@ class AbstractObject(object): """ for attr in here: for key in here[attr]: - if len(here[attr][key]) == 1: - attr_value = here[attr][key].pop('') - else: + if 'type' in here[attr][key]: obj_type = here[attr][key].get("type") - # obj_type = bulk_type.tolist() was necessary before. no clue + if isinstance(obj_type, np.ndarray): # happens on np.load + obj_type = obj_type.tolist() if isinstance(obj_type, bytes): # asthonishingly, this is not necessary under linux. # Found under nt. ??? obj_type = obj_type.decode("UTF-8") obj_type = getattr(tfields, obj_type) attr_value = obj_type._from_new_dict(here[attr][key]) + else: # if len(here[attr][key]) == 1: + attr_value = here[attr][key].pop('') here[attr][key] = attr_value ''' @@ -583,65 +584,6 @@ class AbstractNdarray(np.ndarray, AbstractObject): return inst - @classmethod - def _from_dict(cls, **d): - """ - legacy method - Opposite of old _as_dict - """ - list_dict = {} - kwargs = {} - """ - De-Flatten the first layer of lists - """ - for key in sorted(list(d)): - if "::" in key: - splits = key.split("::") - attr, _, end = key.partition("::") - if attr not in list_dict: - list_dict[attr] = {} - - index, _, end = end.partition("::") - if not index.isdigit(): - raise ValueError("None digit index given") - index = int(index) - if index not in list_dict[attr]: - list_dict[attr][index] = {} - list_dict[attr][index][end] = d[key] - else: - kwargs[key] = d[key] - - """ - Build the lists (recursively) - """ - for key in list(list_dict): - sub_dict = list_dict[key] - list_dict[key] = [] - for index in sorted(list(sub_dict)): - bulk_type = sub_dict[index].get("bulk_type") - bulk_type = bulk_type.tolist() - if isinstance(bulk_type, bytes): - # asthonishingly, this is not necessary under linux. - # Found under nt. ??? - bulk_type = bulk_type.decode("UTF-8") - bulk_type = getattr(tfields, bulk_type) - list_dict[key].append(bulk_type._from_dict(**sub_dict[index])) - - with cls._bypass_setters('fields', demand_existence=False): - ''' - Build the normal way - ''' - bulk = kwargs.pop('bulk') - bulk_type = kwargs.pop('bulk_type') - obj = cls.__new__(cls, bulk, **kwargs) - - ''' - Set list attributes - ''' - for attr, list_value in list_dict.items(): - setattr(obj, attr, list_value) - return obj - class Tensors(AbstractNdarray): """ @@ -769,9 +711,11 @@ class Tensors(AbstractNdarray): if issubclass(type(tensors), np.ndarray): # np.empty pass + elif hasattr(tensors, 'shape'): + dim = dim(tensors) else: raise ValueError( - "Empty tensors need dimension " "parameter 'dim'." + "Empty tensors need dimension parameter 'dim'." ) tensors = np.asarray(tensors, dtype=dtype, order=order) @@ -828,13 +772,13 @@ class Tensors(AbstractNdarray): def merged(cls, *objects, **kwargs): """ Factory method - Merges all tensor inputs to one tensor + Merges all input arguments to one object Args: - **kwargs: passed to cls - dim (int): return_templates (bool): return the templates which can be used together with cut to retrieve the original objects + dim (int): + **kwargs: passed to cls Examples: >>> import numpy as np @@ -928,13 +872,10 @@ class Tensors(AbstractNdarray): for i, obj in enumerate(remainingObjects): tensors = np.append(tensors, obj, axis=0) - if len(tensors) == 0 and 'dim' not in kwargs: + if len(tensors) == 0 and not kwargs.get('dim', None): # if you can not determine the tensor dimension, search for the # first object with some entries - for obj in objects: - if len(obj) != 0: - kwargs['dim'] = dim(obj) - break + kwargs['dim'] = dim(objects[0]) if not return_templates: return cls.__new__(cls, tensors, **kwargs) @@ -2089,7 +2030,7 @@ class Container(AbstractNdarray): if issubclass(type(items), Container): kwargs.setdefault('labels', items.labels) items = items.items - kwargs["items"] = items + kwargs["items"] = kwargs.pop('items_hack', items) cls._update_slot_kwargs(kwargs) empty = np.empty(0, int) @@ -2105,6 +2046,11 @@ class Container(AbstractNdarray): setattr(obj, attr, kwargs[attr]) return obj + def _kwargs(self): + d = super()._kwargs() + d['items_hack'] = d.pop('items') # hack for backwards compatibility + return d + def __getitem__(self, index): return self.items[index] @@ -2138,9 +2084,17 @@ class Maps(sortedcontainers.SortedDict, AbstractObject): sortedcontainers.SortedItemsView): args = tuple([v for k, v in args[0]]) elif len(args) == 1 and isinstance(args[0], list): - if len(args[0]) > 0 and not issubclass(type(args[0][0]), tuple): - # Maps([mp, mp2])- not Maps([]) and not Maps([(key, value)]) - args = tuple(args[0]) + if args[0]: + # not not Maps([]) + if issubclass(type(args[0][0]), tuple): + # Maps([(key, value), (key, value), ...]) + args = tuple(v for k,v in args[0]) + else: + # Maps([mp, mp, ...]) + args = tuple(args[0]) + else: + # Maps([]) -> Maps() + args = tuple() elif len(args) == 1 and issubclass(type(args[0]), dict): # Maps([]) - includes Maps i.e. copy # dangerous because we run beefore super init @@ -2380,7 +2334,7 @@ class TensorMaps(TensorFields): cum_tensor_lengths = [sum(tensor_lengths[:i]) for i in range(len(objects))] - return_value = super(TensorMaps, cls).merged(*objects, **kwargs) + return_value = super().merged(*objects, **kwargs) return_templates = kwargs.get('return_templates', False) if return_templates: inst, templates = return_value @@ -2392,15 +2346,15 @@ class TensorMaps(TensorFields): for dimension, mp in obj.maps.items(): mp = mp + cum_tensor_lengths[i] if dimension not in dim_maps_dict: - dim_maps_dict[dimension] = [] - dim_maps_dict[dimension].append(mp) + dim_maps_dict[dimension] = {} + dim_maps_dict[dimension][i] = mp maps = [] template_maps_list = [[] for i in range(len(objects))] for dimension in sorted(dim_maps_dict): # sort by object index - obj_indices = sorted(dim_maps_dict[dimension].keys()) - dim_maps = [dim_maps_dict[dimension][i] for i in obj_indices] + dim_maps = [dim_maps_dict[dimension][i] + for i in range(len(objects))] return_value = TensorFields.merged( *dim_maps, @@ -2408,14 +2362,14 @@ class TensorMaps(TensorFields): ) if return_templates: mp, dimension_map_templates = return_value - for i in obj_indices: + for i in range(len(objects)): template_maps_list[i].append(dimension_map_templates[i]) else: mp = return_value maps.append(mp) - inst = cls.__new__(cls, inst, maps=maps) - if 'return_templates' in kwargs: + inst.maps = maps + if return_templates: for i, template_maps in enumerate(template_maps_list): templates[i] = tfields.TensorMaps( templates[i], diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py index a35b631341a1cdea0eb6d5738caba9823ed54dbd..9747c29354ce78fe17001ed49b8d1ebb85b321e2 100644 --- a/tfields/mesh3D.py +++ b/tfields/mesh3D.py @@ -233,7 +233,7 @@ class Mesh3D(tfields.TensorMaps): obj = super(Mesh3D, cls).__new__(cls, tensors, *fields, **kwargs) if len(obj.maps) > 1: raise ValueError("Mesh3D only allows one map") - if obj.maps and obj.maps[0].dim != 3: + if obj.maps and (len(obj.maps) > 1 or obj.maps.keys()[0] != 3): raise ValueError("Face dimension should be 3") return obj