From a1ad0d1448ea4cb752ad2f0925fdf47f54805df3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20B=C3=B6ckenhoff?= <dboe@ipp.mpg.de>
Date: Thu, 13 Jul 2023 10:59:18 +0200
Subject: [PATCH] fix: maps serialization

---
 setup.cfg           |   2 +-
 tests/test_core.py  |  26 +++---
 tfields/core.py     | 214 +++++++++++++++++++++-----------------------
 tfields/lib/util.py |  35 ++++++--
 4 files changed, 147 insertions(+), 130 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index be477f5..ff78650 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -48,7 +48,7 @@ packages = find:
 install_requires = 
 	pathlib;python_version<'3.10'
 	six
-	numpy
+	numpy>=1.20.0
 	sympy<=1.6.2  # diffgeom changes in 1.7 see unit branch for first implementation of compatibility
 	scipy
 	rna>=0.6.3
diff --git a/tests/test_core.py b/tests/test_core.py
index be2a588..b703af1 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,6 +1,5 @@
 # pylint:disable=missing-function-docstring,missing-module-docstring,missing-class-docstring,invalid-name,protected-access
 import pickle
-import pathlib
 import unittest
 import uuid
 from tempfile import NamedTemporaryFile
@@ -337,14 +336,15 @@ class TensorMaps_Test(TensorFields_Check, unittest.TestCase):
             tfields.TensorMaps([], maps=tfields.Maps([]), dim=3),
         ]
 
-    def test_legacy(self):
-        this_dir = pathlib.Path(__file__).parent
-        legacy_file = (
-            this_dir
-            / "resources/TensorMaps_0.2.1_ce3ea1fb69058dc39815be65f485abebb487a6bd.npz"
-        )  # NOQA
-        tm = tfields.TensorMaps.load(legacy_file)
-        self.assertTrue(self._inst.equal(tm))
+    # Removed legacy AFTER commit 3fe8d037655f17bc7985a11e9fb64dd9c3d54b7e
+    # def test_legacy(self):
+    #     this_dir = pathlib.Path(__file__).parent
+    #     legacy_file = (
+    #         this_dir
+    #         / "resources/TensorMaps_0.2.1_ce3ea1fb69058dc39815be65f485abebb487a6bd.npz"
+    #     )  # NOQA
+    #     tm = tfields.TensorMaps.load(legacy_file)
+    #     self.assertTrue(self._inst.equal(tm))
 
 
 class TensorMaps_Indexing_Test(unittest.TestCase):
@@ -396,7 +396,13 @@ class TensorMaps_NoFields_Test(Tensors_Check, unittest.TestCase):
 class Maps_Test(Base_Check, unittest.TestCase):
     def demand_equal(self, other):
         super().demand_equal(other)
-        self._inst.equal(other)
+        self.assertTrue(self._inst.equal(other))
+
+    def test_dict(self):
+        super().test_dict()
+        dict_ = self._inst._as_dict()
+        self.assertIsInstance(dict_['args::0::args::0::args::1::args::0'], np.ndarray)
+        self.assertNotIsInstance(dict_['args::0::args::0::args::1::args::0'], tfields.Tensors)
 
     def setUp(self):
         self._inst = tfields.Maps(
diff --git a/tfields/core.py b/tfields/core.py
index 0279719..5ef52bd 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -17,6 +17,8 @@ TODO:
 """
 # builtin
 import warnings
+import typing
+import builtins
 import inspect
 from contextlib import contextmanager
 from collections import Counter
@@ -55,6 +57,63 @@ def dim(tensor):
     return tensor.shape[1]
 
 
+_HIERARCHY_SEPARATOR = "::"
+
+
+def _from_dict(content: dict, eval_type=None):
+    # build the nested dict from the hierarchy-flattened dict keys
+    here = {}
+    for string in content:
+        if string == "type":
+            continue
+
+        value = content[string]
+
+        attr, _, end = string.partition(_HIERARCHY_SEPARATOR)
+        key, _, end = end.partition(_HIERARCHY_SEPARATOR)
+        if attr not in here:
+            here[attr] = {}
+        if key not in here[attr]:
+            here[attr][key] = {}
+        here[attr][key][end] = value
+
+    if "type" not in content:
+        return here[attr][key].pop("")
+
+    obj_type = content.get("type")
+    # build type for recursion
+    if isinstance(obj_type, np.ndarray):  # happens on np.load
+        obj_type = obj_type.tolist()
+    if isinstance(obj_type, bytes):
+        # Surprisingly, this decode step is not needed on Linux;
+        # it was only observed to be required on Windows (os.name == "nt").
+        obj_type = obj_type.decode("UTF-8")
+
+    try:
+        cls = getattr(builtins, obj_type)
+    except AttributeError:
+        cls = getattr(tfields, obj_type)
+
+    # Do the recursion
+    # pylint:disable=consider-using-dict-items
+    for attr in here:
+        # pylint:disable=consider-using-dict-items
+        for key in here[attr]:
+            here[attr][key] = _from_dict(here[attr][key])
+
+    # Build the generic way
+    args = here.pop("args", tuple())
+    args = tuple(args[key] for key in sorted(args))
+    kwargs = here.pop("kwargs", {})
+    assert len(here) == 0
+    if cls in (tuple, list):
+        # TODO: remove this in favour of better packing of tuple?
+        obj = cls(args, **kwargs)
+        return obj
+    obj = cls(*args, **kwargs)
+    return obj
+
+
 class AbstractObject(rna.polymorphism.Storable):
     """
     Abstract base class for all tfields objects implementing polymorphisms
@@ -135,13 +194,14 @@ class AbstractObject(rna.polymorphism.Storable):
         """
         return dict()
 
-    _HIERARCHY_SEPARATOR = "::"
-
-    def _as_dict(self) -> dict:
+    def _as_dict(self, recurse: typing.Dict[str, typing.Callable] = None) -> dict:
         """
         Get an object represenation in a dict format. This is necessary e.g. for saving the full
         file uniquely in the npz format
 
+        Args:
+            recurse: dict of {attribute: callable(iterable) -> dict}
+
         Returns:
             dict: object packed as nested dictionary
         """
@@ -156,12 +216,19 @@ class AbstractObject(rna.polymorphism.Storable):
             ("kwargs", self._kwargs().items()),
         ]:
             for attr, value in iterable:
-                attr = base_attr + self._HIERARCHY_SEPARATOR + attr
-                if hasattr(value, "_as_dict"):
-                    part_dict = value._as_dict()  # pylint: disable=protected-access
+                attr = base_attr + _HIERARCHY_SEPARATOR + attr
+                if (
+                    (recurse is not None and attr in recurse)
+                    or hasattr(value, "_as_dict")
+                ):
+                    if recurse is not None and attr in recurse:
+                        part_dict = recurse[attr](value)
+                    else:
+                        part_dict = value._as_dict()  # pylint: disable=protected-access
+
                     for part_attr, part_value in part_dict.items():
                         content[
-                            attr + self._HIERARCHY_SEPARATOR + part_attr
+                            attr + _HIERARCHY_SEPARATOR + part_attr
                         ] = part_value
                 else:
                     content[attr] = value
@@ -169,110 +236,10 @@ class AbstractObject(rna.polymorphism.Storable):
 
     @classmethod
     def _from_dict(cls, content: dict):
-        try:
-            content.pop("type")
-        except KeyError:
-            # legacy
-            return cls._from_dict_legacy(**content)
-
-        here = {}
-        for string in content:  # TODO no sortelist
-            value = content[string]
-
-            attr, _, end = string.partition(cls._HIERARCHY_SEPARATOR)
-            key, _, end = end.partition(cls._HIERARCHY_SEPARATOR)
-            if attr not in here:
-                here[attr] = {}
-            if key not in here[attr]:
-                here[attr][key] = {}
-            here[attr][key][end] = value
-
-        # Do the recursion
-        # pylint:disable=consider-using-dict-items
-        for attr in here:
-            # pylint:disable=consider-using-dict-items
-            for key in here[attr]:
-                if "type" in here[attr][key]:
-                    obj_type = here[attr][key].get("type")
-                    if isinstance(obj_type, np.ndarray):  # happens on np.load
-                        obj_type = obj_type.tolist()
-                    if isinstance(obj_type, bytes):
-                        # astonishingly, this is not necessary under linux.
-                        # Found under nt. ???
-                        obj_type = obj_type.decode("UTF-8")
-                    obj_type = getattr(tfields, obj_type)
-                    attr_value = (
-                        obj_type._from_dict(  # pylint: disable=protected-access
-                            here[attr][key]
-                        )
-                    )
-                else:  # if len(here[attr][key]) == 1:
-                    attr_value = here[attr][key].pop("")
-                here[attr][key] = attr_value
-
-        # Build the generic way
-        args = here.pop("args", tuple())
-        args = tuple(args[key] for key in sorted(args))
-        kwargs = here.pop("kwargs", {})
-        assert len(here) == 0
-        obj = cls(*args, **kwargs)
-        return obj
-
-    @classmethod
-    def _from_dict_legacy(cls, **content):
-        """
-        legacy method of _from_dict - Opposite of old _as_dict method
-        which is overridden in this version
-        """
-        list_dict = {}
-        kwargs = {}
-        # De-Flatten the first layer of lists
-        for key in sorted(list(content)):
-            if "::" in key:
-                attr, _, end = key.partition("::")
-                if attr not in list_dict:
-                    list_dict[attr] = {}
-
-                index, _, end = end.partition("::")
-                if not index.isdigit():
-                    raise ValueError("None digit index given")
-                index = int(index)
-                if index not in list_dict[attr]:
-                    list_dict[attr][index] = {}
-                list_dict[attr][index][end] = content[key]
-            else:
-                kwargs[key] = content[key]
-
-        # Build the lists (recursively)
-        for key in list(list_dict):
-            sub_dict = list_dict[key]
-            list_dict[key] = []
-            for index in sorted(list(sub_dict)):
-                bulk_type = sub_dict[index].get("bulk_type")
-                bulk_type = bulk_type.tolist()
-                if isinstance(bulk_type, bytes):
-                    # asthonishingly, this is not necessary under linux.
-                    # Found under nt. ???
-                    bulk_type = bulk_type.decode("UTF-8")
-                bulk_type = getattr(tfields, bulk_type)
-                list_dict[key].append(
-                    bulk_type._from_dict_legacy(  # noqa: E501 pylint: disable=protected-access
-                        **sub_dict[index]
-                    )
-                )
+        type_ = content.get("type")
+        assert type_ == cls.__name__
 
-        with cls._bypass_setters(  # pylint: disable=protected-access,no-member
-            "fields", demand_existence=False
-        ):
-            # Build the normal way
-            bulk = kwargs.pop("bulk")
-            bulk_type = kwargs.pop("bulk_type")
-            obj = cls.__new__(cls, bulk, **kwargs)
-
-            # Set list attributes
-            for attr, list_value in list_dict.items():
-                setattr(obj, attr, list_value)
-        return obj
+        return _from_dict(content)
 
 
 class AbstractNdarray(np.ndarray, AbstractObject):
@@ -1952,7 +1919,7 @@ class TensorFields(Tensors):
 
     def _kwargs(self) -> dict:
         content = super()._kwargs()
-        content.pop("fields")
+        content.pop("fields")  # instantiated via _args
         return content
 
     def __getitem__(self, index):
@@ -2411,6 +2378,31 @@ class Maps(sortedcontainers.SortedDict, AbstractObject):
     def _args(self):
         return super()._args() + (list(self.items()),)
 
+    def _as_dict(self, recurse: typing.Dict[str, typing.Callable] = None) -> dict:
+        if recurse is None:
+            recurse = {}
+
+        def recurse_args_0(iterable: typing.List[typing.Tuple[int, typing.Any]]) -> dict:
+            # iterable is list of tuple
+            part_dict = {"type": "list"}
+            for i, (dim, tensor) in enumerate(iterable):
+                content = tensor._as_dict()
+                tuple_key = _HIERARCHY_SEPARATOR.join(["args", str(i), ""])
+                part_dict[tuple_key + "type"] = "tuple"
+                args_key = tuple_key + _HIERARCHY_SEPARATOR.join(["args", ""])
+
+                part_dict[args_key + _HIERARCHY_SEPARATOR.join(["0", "args", "0"])] = dim
+                part_dict[args_key + _HIERARCHY_SEPARATOR.join(["0", "type"])] = 'int'
+
+                for key, value in content.items():
+                    part_dict[args_key + _HIERARCHY_SEPARATOR.join(["1", key])] = value
+
+            return part_dict
+
+        attr = 'args' + _HIERARCHY_SEPARATOR + str(0)
+        recurse[attr] = recurse_args_0
+        return super()._as_dict(recurse=recurse)
+
     def equal(self, other, **kwargs):
         """
         Test equality with other object.
diff --git a/tfields/lib/util.py b/tfields/lib/util.py
index c71978f..cf36565 100644
--- a/tfields/lib/util.py
+++ b/tfields/lib/util.py
@@ -2,6 +2,7 @@
 Various utility functions
 """
 import itertools
+import typing
 from six import string_types
 import numpy as np
 
@@ -20,15 +21,18 @@ def pairwise(iterable):
     return zip(a, b)
 
 
-def flatten(seq, container=None, keep_types=None):
+def flatten(seq, container=None, keep_types=None, key: typing.Callable = None):
     """
     Approach to flatten a nested sequence.
+
     Args:
         seq (iterable): iterable to be flattened
         containter (iterable): iterable defining an append method. Values will
             be appended there
         keep_types (list of type): types that should not be flattened but kept
             in nested form
+        key (callable): callable with the signature key(iterable) -> iterable
+
     Examples:
         >>> from tfields.lib.util import flatten
         >>> import numpy as np
@@ -43,21 +47,36 @@ def flatten(seq, container=None, keep_types=None):
         >>> flatten([[0, 0, 0, 'A'], [1, 2, 3]])
         [0, 0, 0, 'A', 1, 2, 3]
 
+        Dictionaries will return flattened keys
+        >>> flatten({"a": 1, "b": 2})
+        ['a', 'b']
+
+        You can use the key keyword to specify a transformation on the iterable:
+        >>> flatten({"a": {"a1": 1, "a2": 4}, "b": 2}, key=dict.values)
+        [1, 4, 2]
+
+        >>> def dict_flat_key(item):
+        ...     if isinstance(item, dict):
+        ...         return item.values()
+        ...     return item
+        >>> flatten({"a": {"a1": 1, "a2": [3, 4]}, "b": 2}, key=dict_flat_key)
+        [1, 3, 4, 2]
     """
     if keep_types is None:
         keep_types = []
     if container is None:
         container = []
-    # pylint:disable=invalid-name
-    for s in seq:
+    if key is not None:
+        seq = key(seq)
+    for item in seq:
         if (
-            hasattr(s, "__iter__")
-            and not isinstance(s, string_types)
-            and not any((isinstance(s, t) for t in keep_types))
+            hasattr(item, "__iter__")
+            and not isinstance(item, string_types)
+            and not any((isinstance(item, t) for t in keep_types))
         ):
-            flatten(s, container, keep_types)
+            flatten(item, container, keep_types, key=key)
         else:
-            container.append(s)
+            container.append(item)
     return container
 
 
-- 
GitLab