Commit 1e6ced88 by Daniel Boeckenhoff

### new indices method

parent 071c85db
 ... @@ -938,19 +938,33 @@ class Tensors(AbstractNdarray): ... @@ -938,19 +938,33 @@ class Tensors(AbstractNdarray): """ """ Returns: Returns: list of int: indices of tensor occuring list of int: indices of tensor occuring Examples: >>> import tfields >>> p = tfields.Tensors([[1,2,3], [4,5,6], [6,7,8], [4,5,6], ... [4.1, 5, 6]]) >>> p.indices([4,5,6]) array([1, 3]) >>> p.indices([4,5,6.1], rtol=1e-5, atol=1e-1) array([1, 3, 4]) """ """ x, y = np.asarray(self), np.asarray(tensor) x, y = np.asarray(self), np.asarray(tensor) if rtol is None and atol is None: if rtol is None and atol is None: equal_method = np.array_equal equal_method = np.equal else: else: equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol) equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol) indices = [] for i, p in enumerate(x): # inspired by https://stackoverflow.com/questions/19228295/find-ordered-vector-in-numpy-array if equal_method(p, y).all(): indices = np.where(np.all(equal_method((x-y), 0), axis=1))[0] indices.append(i) if early_stopping: break return indices return indices # old manual method # indices = [] # for i, p in enumerate(x): # if equal_method(p, y).all(): # indices.append(i) # if early_stopping: # break # return indices def index(self, tensor, **kwargs): def index(self, tensor, **kwargs): """ """ ... ...
 import numpy as np import numpy as np import functools import functools import tfields.lib.util def ensure_complex(*base_vectors): def ensure_complex(*base_vectors): ... @@ -140,10 +141,10 @@ def base_vectors(array, rtol=None, atol=None): ... @@ -140,10 +141,10 @@ def base_vectors(array, rtol=None, atol=None): """ """ if len(array.shape) == 1: if len(array.shape) == 1: values = sorted(set(array)) values = set(array) if rtol is not None and atol is not None: if rtol is not None and atol is not None: duplicates = set([]) duplicates = set([]) for v1, v2 in tfields.lib.util.pairwise(values): for v1, v2 in tfields.lib.util.pairwise(sorted(values)): if np.isclose(v1, v2, rtol=rtol, atol=atol): if np.isclose(v1, v2, rtol=rtol, atol=atol): duplicates.add(v2) duplicates.add(v2) values = values.difference(duplicates) values = values.difference(duplicates) ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!