diff --git a/tfields/core.py b/tfields/core.py index 92e8d7dedaf60947bf9b8b941f2f5d7e925ce098..38b3f0c34b164681cd942d299c67f4e618f464b1 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -938,19 +938,33 @@ class Tensors(AbstractNdarray): """ Returns: list of int: indices of tensor occuring + Examples: + >>> import tfields + >>> p = tfields.Tensors([[1,2,3], [4,5,6], [6,7,8], [4,5,6], + ... [4.1, 5, 6]]) + >>> p.indices([4,5,6]) + array([1, 3]) + >>> p.indices([4,5,6.1], rtol=1e-5, atol=1e-1) + array([1, 3, 4]) + """ x, y = np.asarray(self), np.asarray(tensor) if rtol is None and atol is None: - equal_method = np.array_equal + equal_method = np.equal else: equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol) - indices = [] - for i, p in enumerate(x): - if equal_method(p, y).all(): - indices.append(i) - if early_stopping: - break + + # inspired by https://stackoverflow.com/questions/19228295/find-ordered-vector-in-numpy-array + indices = np.where(np.all(equal_method((x-y), 0), axis=1))[0] return indices + # old manual method + # indices = [] + # for i, p in enumerate(x): + # if equal_method(p, y).all(): + # indices.append(i) + # if early_stopping: + # break + # return indices def index(self, tensor, **kwargs): """ diff --git a/tfields/lib/grid.py b/tfields/lib/grid.py index d9f886dab613851796f726c7e4f48ca10076e57b..8fde1b8f2f38775d23cb3825de147dc171bc52d5 100644 --- a/tfields/lib/grid.py +++ b/tfields/lib/grid.py @@ -1,5 +1,6 @@ import numpy as np import functools +import tfields.lib.util def ensure_complex(*base_vectors): @@ -140,10 +141,10 @@ def base_vectors(array, rtol=None, atol=None): """ if len(array.shape) == 1: - values = sorted(set(array)) + values = set(array) if rtol is not None and atol is not None: duplicates = set([]) - for v1, v2 in tfields.lib.util.pairwise(values): + for v1, v2 in tfields.lib.util.pairwise(sorted(values)): if np.isclose(v1, v2, rtol=rtol, atol=atol): duplicates.add(v2) values = values.difference(duplicates)