Skip to content
Snippets Groups Projects
Commit e3d73b70 authored by dboe's avatar dboe
Browse files

new remap method

parent 3ed3ff89
No related branches found
No related tags found
No related merge requests found
import unittest
import numpy as np
import tfields
class RemapTest(unittest.TestCase):
def setUp(self):
self.method = tfields.lib.sets.remap
def test_remap_1d(self):
array = np.array([1, 2, 2, 1])
keys = np.array([1, 2])
values = np.array([0, 10])
solution = np.array([0, 10, 10, 0])
result = self.method(array, keys, values)
comparison = result == solution
self.assertTrue(comparison.all())
self.assertIsNot(array, result)
def test_remap_inplace(self):
array = np.array([1, 2, 3, 1])
keys = np.array([1, 2])
values = np.array([0, 10])
solution = np.array([0, 10, 3, 0])
# inplace
result = self.method(array, keys, values, inplace=True)
comparison = result == solution
self.assertTrue(comparison.all())
self.assertIs(array, result)
def test_remap(self):
array = np.array([1, 3, 2, 1]).reshape(2, 2)
keys = np.array([1, 2])
values = np.array([0, 10])
solution = np.array([[0, 3], [10, 0]])
result = self.method(array, keys, values)
comparison = result == solution
self.assertTrue(comparison.all())
def test_remap_large(self):
shape = (20, 200)
num = np.prod(shape)
array = np.arange(num).reshape(shape)
keys = np.arange(num)
values = np.flip(np.arange(num))
solution = np.flip(np.arange(num)).reshape(shape)
result = self.method(array, keys, values)
comparison = result == solution
self.assertTrue(comparison.all())
def test_remap_incomplete(self):
shape = (4, 4)
num = np.prod(shape)
array = np.arange(num).reshape(shape)
keys = np.arange(num)
values = np.arange(num - 1, -1, -1)
keys = keys[1:-1]
values = values[1:-1]
solution = np.flip(np.arange(num)).reshape(shape)
solution[0, 0] = 0 # skipped first value in remap keys
solution[-1, -1] = num - 1 # skipped first value in remap keys
result = self.method(array, keys, values)
comparison = result == solution
self.assertTrue(comparison.all())
......@@ -20,6 +20,7 @@ class UnionFind(object):
ufset.find(obja) != ufset.find(objb)
ufset.union(obja, objb)
"""
def __init__(self):
"""
Create an empty union find data structure.
......@@ -93,7 +94,7 @@ class UnionFind(object):
for i in sets.itervalues():
if i:
out.append(repr(i))
return ', '.join(out)
return ", ".join(out)
def __call__(self, iterator):
"""
......@@ -177,6 +178,39 @@ def disjoint_group_indices(iterator):
return uf.group_indices(iterator)
if __name__ == '__main__':
def remap(
arr: np.ndarray, keys: np.ndarray, values: np.ndarray, inplace=False
) -> np.ndarray:
"""
Given an input array, remap its entries corresponding to 'keys' to 'values'
Args:
input: array to remap
keys: values to be replaced
values : values to replace 'keys' with
Returns:
output:
like 'input', but with elements remapped according to the mapping
defined by 'keys' and 'values'
"""
assert arr.dtype == int
assert arr.min() >= 0
"""
Assuming the values are between 0 and some maximum integer,
one could implement a fast replace by using the numpy-array as int->int dict, like below
"""
if not inplace:
arr = arr.copy()
mp = np.arange(0, arr.max() + 1)
mp[keys] = values
arr.ravel()[:] = mp[arr.ravel()]
return arr
if __name__ == "__main__":
import doctest
doctest.testmod()
......@@ -194,6 +194,7 @@ def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
"""
Examples:
>>> import tfields
>>> import numpy as np
>>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
>>> tfields.lib.util.index(a, [2, 3, 4], axis=0)
2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment