diff --git a/tfields/core.py b/tfields/core.py index ce7a441476d344846e6a543ec83abc25b5d1b470..92e8d7dedaf60947bf9b8b941f2f5d7e925ce098 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -934,25 +934,36 @@ class Tensors(AbstractNdarray): """ return any(self.equal(other, return_bool=False).all(1)) - def indices(self, tensor): + def indices(self, tensor, rtol=None, atol=None, early_stopping=False): """ Returns: list of int: indices of tensor occuring """ + x, y = np.asarray(self), np.asarray(tensor) + if rtol is None and atol is None: + equal_method = np.array_equal + else: + equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol) indices = [] - for i, p in enumerate(self): - if all(p == tensor): + for i, p in enumerate(x): + if equal_method(p, y).all(): indices.append(i) + if early_stopping: + break return indices - def index(self, tensor): + def index(self, tensor, **kwargs): """ Args: tensor Returns: int: index of tensor occuring + Raises: + ValueError: Multiple occurences + use early_stopping=True if first entry should be returned """ - indices = self.indices(tensor) + indices = self.indices(tensor, **kwargs) + print(indices) if not indices: return None if len(indices) == 1: diff --git a/tfields/lib/grid.py b/tfields/lib/grid.py index e3ff95e71dc23f0f6d80d191e623c471cf360f7f..d9f886dab613851796f726c7e4f48ca10076e57b 100644 --- a/tfields/lib/grid.py +++ b/tfields/lib/grid.py @@ -123,7 +123,7 @@ def igrid(*base_vectors, **kwargs): return obj -def base_vectors(array): +def base_vectors(array, rtol=None, atol=None): """ describe the array in terms of base vectors Inverse function of igrid @@ -141,6 +141,17 @@ def base_vectors(array): """ if len(array.shape) == 1: values = sorted(set(array)) + if rtol is not None and atol is not None: + duplicates = set([]) + for v1, v2 in tfields.lib.util.pairwise(values): + if np.isclose(v1, v2, rtol=rtol, atol=atol): + duplicates.add(v2) + values = values.difference(duplicates) + # round to given absolute precision + n_digits = int(abs(np.log10(atol))) + 1 + values = {round(v, n_digits) for v in values} + elif rtol is not None or atol is not None: + raise ValueError("rtol and atol arguments only come in pairs.") spacing = complex(0, len(values)) vmin = min(values) vmax = max(values) @@ -148,7 +159,7 @@ def base_vectors(array): elif len(array.shape) == 2: bases = [] for i in range(array.shape[1]): - bases.append(base_vectors(array[:, i])) + bases.append(base_vectors(array[:, i], rtol=rtol, atol=atol)) return bases else: raise NotImplementedError("Description yet only till rank 1")