From e3d73b70a2249b301f7828e7443c6a077c458743 Mon Sep 17 00:00:00 2001
From: dboe <dboe@ipp.mpg.de>
Date: Sat, 4 Dec 2021 22:22:33 +0100
Subject: [PATCH] new remap method

---
 tests/test_util_sets.py | 69 +++++++++++++++++++++++++++++++++++++++++
 tfields/lib/sets.py     | 38 +++++++++++++++++++++--
 tfields/lib/util.py     |  1 +
 3 files changed, 106 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_util_sets.py

diff --git a/tests/test_util_sets.py b/tests/test_util_sets.py
new file mode 100644
index 0000000..ef4e832
--- /dev/null
+++ b/tests/test_util_sets.py
@@ -0,0 +1,69 @@
+import unittest
+import numpy as np
+import tfields
+
+
+class RemapTest(unittest.TestCase):
+    def setUp(self):
+        self.method = tfields.lib.sets.remap
+
+    def test_remap_1d(self):
+        array = np.array([1, 2, 2, 1])
+        keys = np.array([1, 2])
+        values = np.array([0, 10])
+        solution = np.array([0, 10, 10, 0])
+
+        result = self.method(array, keys, values)
+        comparison = result == solution
+        self.assertTrue(comparison.all())
+        self.assertIsNot(array, result)
+
+    def test_remap_inplace(self):
+        array = np.array([1, 2, 3, 1])
+        keys = np.array([1, 2])
+        values = np.array([0, 10])
+        solution = np.array([0, 10, 3, 0])
+
+        # inplace
+        result = self.method(array, keys, values, inplace=True)
+        comparison = result == solution
+        self.assertTrue(comparison.all())
+        self.assertIs(array, result)
+
+    def test_remap(self):
+        array = np.array([1, 3, 2, 1]).reshape(2, 2)
+        keys = np.array([1, 2])
+        values = np.array([0, 10])
+        solution = np.array([[0, 3], [10, 0]])
+
+        result = self.method(array, keys, values)
+        comparison = result == solution
+        self.assertTrue(comparison.all())
+
+    def test_remap_large(self):
+        shape = (20, 200)
+        num = np.prod(shape)
+        array = np.arange(num).reshape(shape)
+        keys = np.arange(num)
+        values = np.flip(np.arange(num))
+        solution = np.flip(np.arange(num)).reshape(shape)
+
+        result = self.method(array, keys, values)
+        comparison = result == solution
+        self.assertTrue(comparison.all())
+
+    def test_remap_incomplete(self):
+        shape = (4, 4)
+        num = np.prod(shape)
+        array = np.arange(num).reshape(shape)
+        keys = np.arange(num)
+        values = np.arange(num - 1, -1, -1)
+        keys = keys[1:-1]
+        values = values[1:-1]
+        solution = np.flip(np.arange(num)).reshape(shape)
+        solution[0, 0] = 0  # skipped first value in remap keys
+        solution[-1, -1] = num - 1  # skipped first value in remap keys
+
+        result = self.method(array, keys, values)
+        comparison = result == solution
+        self.assertTrue(comparison.all())
diff --git a/tfields/lib/sets.py b/tfields/lib/sets.py
index d52fa73..8180b0f 100644
--- a/tfields/lib/sets.py
+++ b/tfields/lib/sets.py
@@ -20,6 +20,7 @@ class UnionFind(object):
     ufset.find(obja) != ufset.find(objb)
     ufset.union(obja, objb)
     """
+
     def __init__(self):
         """
         Create an empty union find data structure.
@@ -93,7 +94,7 @@ class UnionFind(object):
         for i in sets.itervalues():
             if i:
                 out.append(repr(i))
-        return ', '.join(out)
+        return ", ".join(out)
 
     def __call__(self, iterator):
         """
@@ -177,6 +178,39 @@ def disjoint_group_indices(iterator):
     return uf.group_indices(iterator)
 
 
-if __name__ == '__main__':
+def remap(
+    arr: np.ndarray, keys: np.ndarray, values: np.ndarray, inplace=False
+) -> np.ndarray:
+    """
+    Given an input array, remap its entries corresponding to 'keys' to 'values'
+
+    Args:
+        input: array to remap
+        keys: values to be replaced
+        values : values to replace 'keys' with
+
+    Returns:
+        output:
+            like 'input', but with elements remapped according to the mapping
+            defined by 'keys' and 'values'
+    """
+
+    assert arr.dtype == int
+    assert arr.min() >= 0
+
+    """
+    Assuming the values are between 0 and some maximum integer,
+    one could implement a fast replace by using the numpy-array as int->int dict, like below
+    """
+    if not inplace:
+        arr = arr.copy()
+    mp = np.arange(0, arr.max() + 1)
+    mp[keys] = values
+    arr.ravel()[:] = mp[arr.ravel()]
+    return arr
+
+
+if __name__ == "__main__":
     import doctest
+
     doctest.testmod()
diff --git a/tfields/lib/util.py b/tfields/lib/util.py
index 7850043..410f87a 100644
--- a/tfields/lib/util.py
+++ b/tfields/lib/util.py
@@ -194,6 +194,7 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
     """
     Examples:
         >>> import tfields
+        >>> import numpy as np
         >>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
         >>> tfields.lib.util.index(a, [2, 3, 4], axis=0)
         2
-- 
GitLab