From e3d73b70a2249b301f7828e7443c6a077c458743 Mon Sep 17 00:00:00 2001 From: dboe <dboe@ipp.mpg.de> Date: Sat, 4 Dec 2021 22:22:33 +0100 Subject: [PATCH] new remap method --- tests/test_util_sets.py | 69 +++++++++++++++++++++++++++++++++++++++++ tfields/lib/sets.py | 38 +++++++++++++++++++++-- tfields/lib/util.py | 1 + 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tests/test_util_sets.py diff --git a/tests/test_util_sets.py b/tests/test_util_sets.py new file mode 100644 index 0000000..ef4e832 --- /dev/null +++ b/tests/test_util_sets.py @@ -0,0 +1,69 @@ +import unittest +import numpy as np +import tfields + + +class RemapTest(unittest.TestCase): + def setUp(self): + self.method = tfields.lib.sets.remap + + def test_remap_1d(self): + array = np.array([1, 2, 2, 1]) + keys = np.array([1, 2]) + values = np.array([0, 10]) + solution = np.array([0, 10, 10, 0]) + + result = self.method(array, keys, values) + comparison = result == solution + self.assertTrue(comparison.all()) + self.assertIsNot(array, result) + + def test_remap_inplace(self): + array = np.array([1, 2, 3, 1]) + keys = np.array([1, 2]) + values = np.array([0, 10]) + solution = np.array([0, 10, 3, 0]) + + # inplace + result = self.method(array, keys, values, inplace=True) + comparison = result == solution + self.assertTrue(comparison.all()) + self.assertIs(array, result) + + def test_remap(self): + array = np.array([1, 3, 2, 1]).reshape(2, 2) + keys = np.array([1, 2]) + values = np.array([0, 10]) + solution = np.array([[0, 3], [10, 0]]) + + result = self.method(array, keys, values) + comparison = result == solution + self.assertTrue(comparison.all()) + + def test_remap_large(self): + shape = (20, 200) + num = np.prod(shape) + array = np.arange(num).reshape(shape) + keys = np.arange(num) + values = np.flip(np.arange(num)) + solution = np.flip(np.arange(num)).reshape(shape) + + result = self.method(array, keys, values) + comparison = result == solution + self.assertTrue(comparison.all()) + + def test_remap_incomplete(self): + shape = (4, 4) + num = np.prod(shape) + array = np.arange(num).reshape(shape) + keys = np.arange(num) + values = np.arange(num - 1, -1, -1) + keys = keys[1:-1] + values = values[1:-1] + solution = np.flip(np.arange(num)).reshape(shape) + solution[0, 0] = 0 # skipped first value in remap keys + solution[-1, -1] = num - 1 # skipped first value in remap keys + + result = self.method(array, keys, values) + comparison = result == solution + self.assertTrue(comparison.all()) diff --git a/tfields/lib/sets.py b/tfields/lib/sets.py index d52fa73..8180b0f 100644 --- a/tfields/lib/sets.py +++ b/tfields/lib/sets.py @@ -20,6 +20,7 @@ class UnionFind(object): ufset.find(obja) != ufset.find(objb) ufset.union(obja, objb) """ + def __init__(self): """ Create an empty union find data structure. @@ -93,7 +94,7 @@ class UnionFind(object): for i in sets.itervalues(): if i: out.append(repr(i)) - return ', '.join(out) + return ", ".join(out) def __call__(self, iterator): """ @@ -177,6 +178,39 @@ def disjoint_group_indices(iterator): return uf.group_indices(iterator) -if __name__ == '__main__': +def remap( + arr: np.ndarray, keys: np.ndarray, values: np.ndarray, inplace=False +) -> np.ndarray: + """ + Given an input array, remap its entries corresponding to 'keys' to 'values' + + Args: + input: array to remap + keys: values to be replaced + values : values to replace 'keys' with + + Returns: + output: + like 'input', but with elements remapped according to the mapping + defined by 'keys' and 'values' + """ + + assert arr.dtype == int + assert arr.min() >= 0 + + """ + Assuming the values are between 0 and some maximum integer, + one could implement a fast replace by using the numpy-array as int->int dict, like below + """ + if not inplace: + arr = arr.copy() + mp = np.arange(0, arr.max() + 1) + mp[keys] = values + arr.ravel()[:] = mp[arr.ravel()] + return arr + + +if __name__ == "__main__": import doctest + doctest.testmod() diff --git a/tfields/lib/util.py b/tfields/lib/util.py index 7850043..410f87a 100644 --- a/tfields/lib/util.py +++ b/tfields/lib/util.py @@ -194,6 +194,7 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None): """ Examples: >>> import tfields + >>> import numpy as np >>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]]) >>> tfields.lib.util.index(a, [2, 3, 4], axis=0) 2 -- GitLab