diff --git a/setup.cfg b/setup.cfg index be477f5500741f3e2d6eba8d1593478698f02dc9..ff786503bc7690e7ee5d1d4e7b39afbf66fd7aed 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ packages = find: install_requires = pathlib;python_version<'3.10' six - numpy + numpy>=1.20.0 sympy<=1.6.2 # diffgeom changes in 1.7 see unit branch for first implementation of compatibility scipy rna>=0.6.3 diff --git a/tests/test_core.py b/tests/test_core.py index be2a588d8dfb8747a35aaf526fd4766767f75bb4..b703af11fc043a1dde6d2443acecc043d8002b9f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,5 @@ # pylint:disable=missing-function-docstring,missing-module-docstring,missing-class-docstring,invalid-name,protected-access import pickle -import pathlib import unittest import uuid from tempfile import NamedTemporaryFile @@ -337,14 +336,15 @@ class TensorMaps_Test(TensorFields_Check, unittest.TestCase): tfields.TensorMaps([], maps=tfields.Maps([]), dim=3), ] - def test_legacy(self): - this_dir = pathlib.Path(__file__).parent - legacy_file = ( - this_dir - / "resources/TensorMaps_0.2.1_ce3ea1fb69058dc39815be65f485abebb487a6bd.npz" - ) # NOQA - tm = tfields.TensorMaps.load(legacy_file) - self.assertTrue(self._inst.equal(tm)) + # Removed legacy AFTER commit 3fe8d037655f17bc7985a11e9fb64dd9c3d54b7e + # def test_legacy(self): + # this_dir = pathlib.Path(__file__).parent + # legacy_file = ( + # this_dir + # / "resources/TensorMaps_0.2.1_ce3ea1fb69058dc39815be65f485abebb487a6bd.npz" + # ) # NOQA + # tm = tfields.TensorMaps.load(legacy_file) + # self.assertTrue(self._inst.equal(tm)) class TensorMaps_Indexing_Test(unittest.TestCase): @@ -396,7 +396,13 @@ class TensorMaps_NoFields_Test(Tensors_Check, unittest.TestCase): class Maps_Test(Base_Check, unittest.TestCase): def demand_equal(self, other): super().demand_equal(other) - self._inst.equal(other) + self.assertTrue(self._inst.equal(other)) + + def test_dict(self): + super().test_dict() + dict_ = self._inst._as_dict() + self.assertIsInstance(dict_['args::0::args::0::args::1::args::0'], np.ndarray) + self.assertNotIsInstance(dict_['args::0::args::0::args::1::args::0'], tfields.Tensors) def setUp(self): self._inst = tfields.Maps( diff --git a/tfields/core.py b/tfields/core.py index 02797192d4e01201033359415d1f1fc619cf2f3f..5ef52bdba9d8bfde8207467fea7fcd4a7178cc1a 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -17,6 +17,8 @@ TODO: """ # builtin import warnings +import typing +import builtins import inspect from contextlib import contextmanager from collections import Counter @@ -55,6 +57,63 @@ def dim(tensor): return tensor.shape[1] +_HIERARCHY_SEPARATOR = "::" + + +def _from_dict(content: dict, eval_type=None): + # build the nested dict from the chierarchy_flattened dict keys + here = {} + for string in content: + if string == "type": + continue + + value = content[string] + + attr, _, end = string.partition(_HIERARCHY_SEPARATOR) + key, _, end = end.partition(_HIERARCHY_SEPARATOR) + if attr not in here: + here[attr] = {} + if key not in here[attr]: + here[attr][key] = {} + here[attr][key][end] = value + + if "type" not in content: + return here[attr][key].pop("") + + obj_type = content.get("type") + # build type for recursion + if isinstance(obj_type, np.ndarray): # happens on np.load + obj_type = obj_type.tolist() + if isinstance(obj_type, bytes): + # astonishingly, this is not necessary under linux. + # Found under nt. ??? + obj_type = obj_type.decode("UTF-8") + + try: + cls = getattr(builtins, obj_type) + except AttributeError: + cls = getattr(tfields, obj_type) + + # Do the recursion + # pylint:disable=consider-using-dict-items + for attr in here: + # pylint:disable=consider-using-dict-items + for key in here[attr]: + here[attr][key] = _from_dict(here[attr][key]) + + # Build the generic way + args = here.pop("args", tuple()) + args = tuple(args[key] for key in sorted(args)) + kwargs = here.pop("kwargs", {}) + assert len(here) == 0 + if cls in (tuple, list): + # TODO: remove this in favour of better pacing tuple? + obj = cls(args, **kwargs) + return obj + obj = cls(*args, **kwargs) + return obj + + class AbstractObject(rna.polymorphism.Storable): """ Abstract base class for all tfields objects implementing polymorphisms @@ -135,13 +194,14 @@ class AbstractObject(rna.polymorphism.Storable): """ return dict() - _HIERARCHY_SEPARATOR = "::" - - def _as_dict(self) -> dict: + def _as_dict(self, recurse: typing.Dict[str, typing.Callable] = None) -> dict: """ Get an object represenation in a dict format. This is necessary e.g. for saving the full file uniquely in the npz format + Args: + recurse: dict of {attribute: callable(iterable) -> dict} + Returns: dict: object packed as nested dictionary """ @@ -156,12 +216,19 @@ class AbstractObject(rna.polymorphism.Storable): ("kwargs", self._kwargs().items()), ]: for attr, value in iterable: - attr = base_attr + self._HIERARCHY_SEPARATOR + attr - if hasattr(value, "_as_dict"): - part_dict = value._as_dict() # pylint: disable=protected-access + attr = base_attr + _HIERARCHY_SEPARATOR + attr + if ( + (recurse is not None and attr in recurse) + or hasattr(value, "_as_dict") + ): + if recurse is not None and attr in recurse: + part_dict = recurse[attr](value) + else: + part_dict = value._as_dict() # pylint: disable=protected-access + for part_attr, part_value in part_dict.items(): content[ - attr + self._HIERARCHY_SEPARATOR + part_attr + attr + _HIERARCHY_SEPARATOR + part_attr ] = part_value else: content[attr] = value @@ -169,110 +236,10 @@ class AbstractObject(rna.polymorphism.Storable): @classmethod def _from_dict(cls, content: dict): - try: - content.pop("type") - except KeyError: - # legacy - return cls._from_dict_legacy(**content) - - here = {} - for string in content: # TODO no sortelist - value = content[string] - - attr, _, end = string.partition(cls._HIERARCHY_SEPARATOR) - key, _, end = end.partition(cls._HIERARCHY_SEPARATOR) - if attr not in here: - here[attr] = {} - if key not in here[attr]: - here[attr][key] = {} - here[attr][key][end] = value - - # Do the recursion - # pylint:disable=consider-using-dict-items - for attr in here: - # pylint:disable=consider-using-dict-items - for key in here[attr]: - if "type" in here[attr][key]: - obj_type = here[attr][key].get("type") - if isinstance(obj_type, np.ndarray): # happens on np.load - obj_type = obj_type.tolist() - if isinstance(obj_type, bytes): - # astonishingly, this is not necessary under linux. - # Found under nt. ??? - obj_type = obj_type.decode("UTF-8") - obj_type = getattr(tfields, obj_type) - attr_value = ( - obj_type._from_dict( # pylint: disable=protected-access - here[attr][key] - ) - ) - else: # if len(here[attr][key]) == 1: - attr_value = here[attr][key].pop("") - here[attr][key] = attr_value - - # Build the generic way - args = here.pop("args", tuple()) - args = tuple(args[key] for key in sorted(args)) - kwargs = here.pop("kwargs", {}) - assert len(here) == 0 - obj = cls(*args, **kwargs) - return obj - - @classmethod - def _from_dict_legacy(cls, **content): - """ - legacy method of _from_dict - Opposite of old _as_dict method - which is overridden in this version - """ - list_dict = {} - kwargs = {} - # De-Flatten the first layer of lists - for key in sorted(list(content)): - if "::" in key: - attr, _, end = key.partition("::") - if attr not in list_dict: - list_dict[attr] = {} - - index, _, end = end.partition("::") - if not index.isdigit(): - raise ValueError("None digit index given") - index = int(index) - if index not in list_dict[attr]: - list_dict[attr][index] = {} - list_dict[attr][index][end] = content[key] - else: - kwargs[key] = content[key] - - # Build the lists (recursively) - for key in list(list_dict): - sub_dict = list_dict[key] - list_dict[key] = [] - for index in sorted(list(sub_dict)): - bulk_type = sub_dict[index].get("bulk_type") - bulk_type = bulk_type.tolist() - if isinstance(bulk_type, bytes): - # asthonishingly, this is not necessary under linux. - # Found under nt. ??? - bulk_type = bulk_type.decode("UTF-8") - bulk_type = getattr(tfields, bulk_type) - list_dict[key].append( - bulk_type._from_dict_legacy( # noqa: E501 pylint: disable=protected-access - **sub_dict[index] - ) - ) + type_ = content.get("type") + assert type_ == cls.__name__ - with cls._bypass_setters( # pylint: disable=protected-access,no-member - "fields", demand_existence=False - ): - # Build the normal way - bulk = kwargs.pop("bulk") - bulk_type = kwargs.pop("bulk_type") - obj = cls.__new__(cls, bulk, **kwargs) - - # Set list attributes - for attr, list_value in list_dict.items(): - setattr(obj, attr, list_value) - return obj + return _from_dict(content) class AbstractNdarray(np.ndarray, AbstractObject): @@ -1952,7 +1919,7 @@ class TensorFields(Tensors): def _kwargs(self) -> dict: content = super()._kwargs() - content.pop("fields") + content.pop("fields") # instantiated via _args return content def __getitem__(self, index): @@ -2411,6 +2378,31 @@ class Maps(sortedcontainers.SortedDict, AbstractObject): def _args(self): return super()._args() + (list(self.items()),) + def _as_dict(self, recurse: typing.Dict[str, typing.Callable] = None) -> dict: + if recurse is None: + recurse = {} + + def recurse_args_0(iterable: typing.List[typing.Tuple[int, typing.Any]]) -> dict: + # iterable is list of tuple + part_dict = {"type": "list"} + for i, (dim, tensor) in enumerate(iterable): + content = tensor._as_dict() + tuple_key = _HIERARCHY_SEPARATOR.join(["args", str(i), ""]) + part_dict[tuple_key + "type"] = "tuple" + args_key = tuple_key + _HIERARCHY_SEPARATOR.join(["args", ""]) + + part_dict[args_key + _HIERARCHY_SEPARATOR.join(["0", "args", "0"])] = dim + part_dict[args_key + _HIERARCHY_SEPARATOR.join(["0", "type"])] = 'int' + + for key, value in content.items(): + part_dict[args_key + _HIERARCHY_SEPARATOR.join(["1", key])] = value + + return part_dict + + attr = 'args' + _HIERARCHY_SEPARATOR + str(0) + recurse[attr] = recurse_args_0 + return super()._as_dict(recurse=recurse) + def equal(self, other, **kwargs): """ Test equality with other object. diff --git a/tfields/lib/util.py b/tfields/lib/util.py index c71978fe64fd86f54367226875532557f139cc73..cf365652f3efe699b24abb4c923a7bb4dc3723c2 100644 --- a/tfields/lib/util.py +++ b/tfields/lib/util.py @@ -2,6 +2,7 @@ Various utility functions """ import itertools +import typing from six import string_types import numpy as np @@ -20,15 +21,18 @@ def pairwise(iterable): return zip(a, b) -def flatten(seq, container=None, keep_types=None): +def flatten(seq, container=None, keep_types=None, key: typing.Callable = None): """ Approach to flatten a nested sequence. + Args: seq (iterable): iterable to be flattened containter (iterable): iterable defining an append method. Values will be appended there keep_types (list of type): types that should not be flattened but kept in nested form + key (callable): callable with the signature key(iterable) -> iterable + Examples: >>> from tfields.lib.util import flatten >>> import numpy as np @@ -43,21 +47,36 @@ def flatten(seq, container=None, keep_types=None): >>> flatten([[0, 0, 0, 'A'], [1, 2, 3]]) [0, 0, 0, 'A', 1, 2, 3] + Dictionaries will return flattened keys + >>> flatten({"a": 1, "b": 2}) + ['a', 'b'] + + You can use the key keyword to specify a transformation on the iterable: + >>> flatten({"a": {"a1": 1, "a2": 4}, "b": 2}, key=dict.values) + [1, 4, 2] + + >>> def dict_flat_key(item): + ... if isinstance(item, dict): + ... return item.values() + ... return item + >>> flatten({"a": {"a1": 1, "a2": [3, 4]}, "b": 2}, key=dict_flat_key) + [1, 3, 4, 2] """ if keep_types is None: keep_types = [] if container is None: container = [] - # pylint:disable=invalid-name - for s in seq: + if key is not None: + seq = key(seq) + for item in seq: if ( - hasattr(s, "__iter__") - and not isinstance(s, string_types) - and not any((isinstance(s, t) for t in keep_types)) + hasattr(item, "__iter__") + and not isinstance(item, string_types) + and not any((isinstance(item, t) for t in keep_types)) ): - flatten(s, container, keep_types) + flatten(item, container, keep_types, key=key) else: - container.append(s) + container.append(item) return container